From 49f694b1f784c286e6a9e96b8409ebeaf719f02d Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Sun, 14 Apr 2024 22:48:17 -0500 Subject: [PATCH 01/12] Add main eagle script modeling_eagle.py Signed-off-by: Siddhi Velankar --- .../transformers/modeling/modeling_eagle.py | 2120 +++++++++++++++++ 1 file changed, 2120 insertions(+) create mode 100644 intel_extension_for_transformers/transformers/modeling/modeling_eagle.py diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_eagle.py b/intel_extension_for_transformers/transformers/modeling/modeling_eagle.py new file mode 100644 index 00000000000..3cd486bdbac --- /dev/null +++ b/intel_extension_for_transformers/transformers/modeling/modeling_eagle.py @@ -0,0 +1,2120 @@ +import os +import random +import copy +import json +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig +import math +from typing import List, Optional, Tuple, Union +import torch.nn.functional as F +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from huggingface_hub import hf_hub_download +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) +from transformers.cache_utils import Cache, DynamicCache + +# Define the tree and chain structures for the model. +tree_structure = [[0], [1], [2], [3], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [2, 1], [3, 0] + , [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1], [1, 0, 0], + [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]] + +chain_structure = [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0]] + +# The model structure of EAGLE is largely based on a single Decoder layer from LLaMA, with the model definition essentially copied from LLaMA. + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + This function generates a mask that prevents future tokens from attending to earlier tokens in the sequence. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + This function is used to adjust the mask dimensions to match the expected input shape for attention mechanisms. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Repeats the key-value pairs in the hidden states tensor. + This function is used to adjust the dimensions of the hidden states tensor to match the expected input shape for attention mechanisms. + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """ + Rotates half the hidden dimensions of the input tensor. + This function is used to apply rotary positional embeddings, which help in modeling the relative positions of tokens in the sequence. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + """ + Applies rotary positional embeddings to the query and key tensors. + This function is used to incorporate positional information into the attention mechanism, enhancing the model's ability to understand the order of tokens in the sequence. + """ + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class EAGLERotaryEmbedding(torch.nn.Module): + """ + EAGLERotaryEmbedding class implements rotary embeddings for positional encoding in transformer models. + It calculates and caches the cosine and sine values for the embeddings to speed up computation. + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + """ + Initializes the EAGLERotaryEmbedding with the given dimensions, maximum position embeddings, base, and device. + """ + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Pre-calculate and cache cosine and sine values for the embeddings. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + """ + Sets up the cosine and sine caches for the given sequence length, device, and data type. + This method is called during initialization and when the sequence length exceeds the cached length. + """ + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + """ + Forward pass of the EAGLERotaryEmbedding. + If the sequence length exceeds the cached length, it updates the cosine and sine caches. + Returns the cosine and sine values for the embeddings. + """ + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class EAGLELinearScalingRotaryEmbedding(EAGLERotaryEmbedding): + """ + EAGLELinearScalingRotaryEmbedding extends EAGLERotaryEmbedding with linear scaling. + It adjusts the sequence length by a scaling factor before calculating the cosine and sine values. + LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + """ + Initializes the EAGLELinearScalingRotaryEmbedding with the given dimensions, maximum position embeddings, base, device, and scaling factor. + """ + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + """ + Overrides the _set_cos_sin_cache method to apply linear scaling to the sequence length before calculating the cosine and sine values. + """ + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class EAGLEDynamicNTKScalingRotaryEmbedding(EAGLERotaryEmbedding): + """ + EAGLEDynamicNTKScalingRotaryEmbedding extends EAGLERotaryEmbedding with dynamic NTK scaling. + It adjusts the base and sequence length dynamically based on the scaling factor and sequence length. + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + """ + Initializes the EAGLEDynamicNTKScalingRotaryEmbedding with the given dimensions, maximum position embeddings, base, device, and scaling factor. + """ + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + """ + Overrides the _set_cos_sin_cache method to apply dynamic NTK scaling to the sequence length before calculating the cosine and sine values. + """ + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) +class EAGLE_Config(PretrainedConfig): + r""" + Copyed from LlamaConfig, the structure of EAGLE consists of a single LlamaDecoder layer. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + pretraining_tp (`int`, *optional*, defaults to `1`): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + + Example: + + """ + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + **kwargs, + ): + """ + Initializes the model with the given configuration parameters. + + Parameters: + - vocab_size: The size of the vocabulary. Default is 32000. + - hidden_size: The size of the hidden layers in the model. Default is 4096. + - intermediate_size: The size of the intermediate layers in the model. Default is 11008. + - num_hidden_layers: The number of hidden layers in the model. Default is 32. + - num_attention_heads: The number of attention heads in the model. Default is 32. + - num_key_value_heads: The number of key-value heads in the model. Default is None, which sets it to the same as num_attention_heads. + - hidden_act: The activation function to use in the hidden layers. Default is "silu". + - max_position_embeddings: The maximum number of position embeddings. Default is 2048. + - initializer_range: The range for the initializer. Default is 0.02. + - rms_norm_eps: The epsilon value for RMS normalization. Default is 1e-6. + - use_cache: Whether to use cache for the model. Default is True. + - pad_token_id: The ID of the padding token. Default is None. + - bos_token_id: The ID of the beginning-of-sentence token. Default is 1. + - eos_token_id: The ID of the end-of-sentence token. Default is 2. + - pretraining_tp: The pretraining type. Default is 1. + - tie_word_embeddings: Whether to tie the word embeddings. Default is False. + - rope_scaling: The configuration for ROPE scaling. Default is None. + - **kwargs: Additional keyword arguments. + """ + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # For backward compatibility, set num_key_value_heads to num_attention_heads if not provided. + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validates the `rope_scaling` configuration. + + Raises: + - ValueError: If `rope_scaling` is not None and does not meet the expected format or values. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + +class EAGLEAttention(nn.Module): + """ + EAGLEAttention implements multi-headed attention mechanism inspired by the 'Attention Is All You Need' paper. + It includes support for rotary positional embeddings (RoPE) with linear or dynamic scaling. + """ + + def __init__(self, config): + """ + Initializes the EAGLEAttention module with the given configuration. + + Parameters: + - config: A configuration object containing model parameters. + """ + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + """ + Initializes the rotary positional embeddings (RoPE) based on the configuration. + """ + if self.config.rope_scaling is None: + self.rotary_emb = EAGLERotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = EAGLELinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = EAGLEDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + """ + Reshapes the input tensor to the required shape for multi-head attention. + + Parameters: + - tensor: The input tensor. + - seq_len: The sequence length. + - bsz: The batch size. + + Returns: + - The reshaped tensor. + """ + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass of the EAGLEAttention module. + + Parameters: + - hidden_states: The input hidden states. + - attention_mask: The attention mask tensor. + - position_ids: The position IDs tensor. + - past_key_value: The past key-value states. + - output_attentions: Whether to output attention weights. + - use_cache: Whether to use cache for the model. + + Returns: + - The output tensor. + - The attention weights (if output_attentions is True). + - The past key-value states (if use_cache is True). + """ + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + # Splitting the projection weights for pretraining + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # Repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class EAGLEMLP(nn.Module): + """ + EAGLEMLP implements the MLP (Multi-Layer Perceptron) component of the EAGLE model. + It includes support for pretraining with multiple slices. + """ + + def __init__(self, config): + """ + Initializes the EAGLEMLP module with the given configuration. + + Parameters: + - config: A configuration object containing model parameters. + """ + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + +class EAGLERMSNorm(nn.Module): + """ + EAGLERMSNorm implements Root Mean Square Layer Normalization (RMSNorm) for stabilizing the training of deep neural networks. + It normalizes the input tensor by scaling it with a learned weight and adjusting its variance. + """ + + def __init__(self, hidden_size, eps=1e-6): + """ + Initializes the EAGLERMSNorm module with the given hidden size and epsilon for numerical stability. + + Parameters: + - hidden_size: The size of the hidden layer. + - eps: A small value added to the variance for numerical stability. Default is 1e-6. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """ + Forward pass of the EAGLERMSNorm module. + + Parameters: + - hidden_states: The input tensor to be normalized. + + Returns: + - The normalized tensor. + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class EAGLEDecoderLayer(nn.Module): + """ + EAGLEDecoderLayer represents a single layer of the EAGLE decoder, which includes self-attention and a multi-layer perceptron (MLP). + It also optionally includes layer normalization before the self-attention mechanism. + """ + + def __init__(self, config, index): + """ + Initializes the EAGLEDecoderLayer with the given configuration and index. + + Parameters: + - config: A configuration object containing model parameters. + - index: The index of the layer within the decoder. + """ + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = EAGLEAttention(config=config) + self.mlp = EAGLEMLP(config) + self.index = index + if self.index != 0: + self.input_layernorm = EAGLERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = EAGLERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + + Forward pass of the EAGLEDecoderLayer. + + Parameters: + - hidden_states: The input tensor of shape `(batch, seq_len, embed_dim)`. + - attention_mask: An optional tensor of shape `(batch, 1, tgt_len, src_len)` indicating padding elements. + - position_ids: An optional tensor of shape `(batch, seq_len)` indicating the position of each token in the sequence. + - past_key_value: An optional tuple of tensors containing past key and value states for caching. + - output_attentions: A boolean indicating whether to return the attention weights. + - use_cache: A boolean indicating whether to use caching for faster decoding. + + Returns: + - The output tensor of shape `(batch, seq_len, embed_dim)`. + - Optionally, the attention weights and past key-value states. + """ + residual = hidden_states + + if self.index != 0: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class node: + """ + Node represents a node in a tree structure, with support for tracking parent-child relationships and depth. + It also allows for storing a value and a dictionary key associated with the node. + """ + + def __init__(self, parent=None, value=None, dict_key=None): + """ + Initializes a Node with an optional parent, value, and dictionary key. + + Parameters: + - parent: The parent node of this node. Default is None. + - value: The value associated with this node. Default is None. + - dict_key: The dictionary key associated with this node. Default is None. + """ + self.parent = parent + self.value = value + if parent: + self.depth = parent.depth + 1 + parent.children.append(self) + else: + self.depth = 0 + self.children = [] + self.dict_key = dict_key + + def is_leaf(self): + """ + Checks if the node is a leaf node (i.e., has no children). + + Returns: + - True if the node is a leaf, False otherwise. + """ + return len(self.children) == 0 + + def all_index(self): + """ + Returns the index path from the root to this node. + + Returns: + - A list of indices representing the path from the root to this node. + """ + if not self.parent.parent: + return [self.index] + else: + return self.parent.all_index() + [self.index] + +class Tree: + """ + Tree represents a tree structure, with support for tracking parent-child relationships, depth, and indexing. + It is initialized with a list of tree nodes, each represented as a list of values. + """ + + def __init__(self, tree_list): + """ + Initializes the Tree with the given list of tree nodes. + + Parameters: + - tree_list: A list of tree nodes, each represented as a list of values. + """ + sorted_tree_list = sorted(tree_list, key=lambda x: (len(x), x)) + self.root = node() + self.node_dic = {} + for tree_node in sorted_tree_list: + cur_value = tree_node[-1] + if len(tree_node) == 1: + cur_node = node(parent=self.root, value=cur_value, dict_key=tuple(tree_node)) + else: + cur_parent = self.node_dic[tuple(tree_node[:-1])] + cur_node = node(parent=cur_parent, value=cur_value, dict_key=tuple(tree_node)) + self.node_dic[tuple(tree_node)] = cur_node + self.indexnode() + + def max_depth(self): + """ + Returns the maximum depth of the tree. + + Returns: + - The maximum depth of the tree. + """ + return max([item.depth for item in self.node_dic.values()]) + + def num_node_wchild(self): + """ + Returns the number of nodes in the tree that have children. + + Returns: + - The number of nodes with children. + """ + num_c = 0 + for item in self.node_dic.values(): + if not item.is_leaf(): + num_c += 1 + return num_c + + def get_node_wchild(self): + """ + Returns a list of nodes in the tree that have children. + + Returns: + - A list of nodes with children. + """ + ns = [] + for item in self.node_dic.values(): + if not item.is_leaf(): + ns.append(item) + return ns + + def indexnode(self): + """ + Assigns an index to each non-leaf node in the tree. + """ + cur_index = 0 + for key in self.node_dic: + cur_node = self.node_dic[key] + if not cur_node.is_leaf(): + cur_node.index = cur_index + cur_index += 1 + + +def generate_tree_buffers_for_eagle(tree_choices, device="cuda"): + """ + Generates tree buffers for the EAGLE model based on the given tree choices. + + Parameters: + - tree_choices: A list of tree nodes, each represented as a list of values. + - device: The device to which the tensors should be moved. Default is "cuda". + + Returns: + - A dictionary containing the tree buffers. + """ + TOPK = 5 + tree = Tree(tree_choices) + tree_len = tree.num_node_wchild() + + max_depth = tree.max_depth() + nodes_wc = tree.get_node_wchild() + + depth_counts = [0 for _ in range(max_depth - 1)] + for x in nodes_wc: + depth_counts[x.depth - 1] += 1 + depth_counts_sum = [sum(depth_counts[:i + 1]) for i in range(len(depth_counts))] + + tree_attn_mask = torch.eye(tree_len, tree_len) + + for id, x in enumerate(nodes_wc): + tree_attn_mask[id, x.all_index()] = 1 + + tree_attn_mask_list0 = [tree_attn_mask[:ml, :ml] for ml in depth_counts_sum] + tree_attn_mask_list = [] + for id, x in enumerate(tree_attn_mask_list0): + x = x[-depth_counts[id]:] + tree_attn_mask_list.append(x) + + tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] + repeat_nums = [[] for _ in depth_counts] + start = 0 + bias = 0 + for i in range(len(depth_counts)): + bias = 0 + repeat_j = 0 + for j in range(depth_counts[i]): + cur_node = nodes_wc[start + j] + cur_parent = cur_node.parent + + if j != 0: + if cur_parent != parent: + bias += 1 + parent = cur_parent + repeat_nums[i].append(j - repeat_j) + repeat_j = j + else: + parent = cur_parent + tree_indices_list[i][j] = cur_node.value + TOPK * (bias) + repeat_nums[i].append(j - repeat_j + 1) + start += depth_counts[i] + + position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] + + tree_buffers = { + "attn_mask": [i.unsqueeze(0).unsqueeze(0) for i in tree_attn_mask_list], + "tree_indices": tree_indices_list, + "position_ids": position_ids, + "repeat_nums": repeat_nums + } + + # Move the tensors in the dictionary to the specified device + tree_buffers = { + k: [i.clone().to(device) for i in v] + if isinstance(v[0], torch.Tensor) + else ( + torch.tensor(v, device=device) + if isinstance(v, torch.Tensor) + else v + ) + for k, v in tree_buffers.items() + } + return tree_buffers + +class EAGLEModel(nn.Module): + """ + EAGLEModel is a custom PyTorch model designed for the EAGLE architecture, which includes a series of decoder layers and a final linear layer. + It also supports gradient checkpointing and initialization of tree buffers for the EAGLE model. + """ + + def __init__(self, config, bias=True): + """ + Initializes the EAGLEModel with the given configuration and an optional bias parameter for the final linear layer. + + Parameters: + - config: A configuration object containing model parameters. + - bias: A boolean indicating whether to include bias in the final linear layer. Default is True. + """ + super().__init__() + self.gradient_checkpointing = False + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([EAGLEDecoderLayer(config, index) for index in range(config.num_hidden_layers)]) + self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias) + self.act = ACT2FN[config.hidden_act] + + def init_tree(self): + """ + Initializes the tree buffers for the EAGLE model. + """ + self.tree_buffer = generate_tree_buffers_for_eagle(self.tree, self.embed_tokens.weight.device) + + def reset(self): + """ + Resets the tree mask to None. + """ + self.tree_mask = None + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + """ + Prepares the decoder attention mask by creating a causal mask and combining it with the provided attention mask. + + Parameters: + - attention_mask: The provided attention mask. + - input_shape: The shape of the input tensor. + - inputs_embeds: The input embeddings. + - past_key_values_length: The length of past key values. + + Returns: + - The combined attention mask. + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + torch.float32, # [MODIFIED] force to cast to float32 + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + # [MODIFIED] add tree mask + if hasattr(self, "tree_mask") and self.tree_mask is not None: + tree_mask = self.tree_mask + tree_len = tree_mask.size(-1) + bs = combined_attention_mask.size(0) + combined_attention_mask[:, :, -tree_len:, -tree_len:][ + tree_mask.repeat(bs, 1, 1, 1) == 0 + ] = torch.finfo(torch.float32).min + + return combined_attention_mask + + def forward( + self, + hidden_states, + input_ids, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + """ + Forward pass of the EAGLEModel. + + Parameters: + - hidden_states: The input hidden states. + - input_ids: The input token IDs. + - attention_mask: An optional attention mask. + - position_ids: An optional tensor of position IDs. + - past_key_values: An optional list of past key values. + - use_cache: A boolean indicating whether to use caching. + - output_attentions: A boolean indicating whether to output attention weights. + - output_hidden_states: A boolean indicating whether to output hidden states. + + Returns: + - The output hidden states and optionally the next decoder cache and hidden states. + """ + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + with torch.no_grad(): + inputs_embeds = self.embed_tokens(input_ids) + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: + device = hidden_states.device if hidden_states is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) + + all_hidden_states = () if output_hidden_states else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if use_cache: + return hidden_states, next_decoder_cache + + return hidden_states + + def reset_kv(self): + """ + Resets the stable key-value pairs to None. + """ + self.stable_kv = None + + @torch.no_grad() + def repeat_hidden(self, hidden_state, repeat_num): + """ + Repeats the hidden state according to the given repeat numbers. + + Parameters: + - hidden_state: The input hidden state. + - repeat_num: A list of repeat numbers. + + Returns: + - The repeated hidden state. + """ + new_hidden = [] + for id, i in enumerate(repeat_num): + new_hidden.append(hidden_state[:, id:id + 1].repeat(1, i, 1)) + return torch.cat(new_hidden, dim=1) + + + def sample(self, logits, logits_processor, k=1): + """ + Samples from the logits using the provided logits processor. + + Parameters: + - logits: The input logits. + - logits_processor: A function to process the logits. + - k: The number of samples to generate. Default is 1. + + Returns: + - The sampled indices, probabilities, and optionally the operation. + """ + bs, seq_len, _ = logits.shape + logits = logits.view(-1, logits.shape[-1]) + logits = logits_processor(None, logits) + probabilities = torch.nn.functional.softmax(logits, dim=-1) + sampled_indices = torch.multinomial(probabilities, k, replacement=False) + sampled_probs = torch.gather(probabilities, -1, sampled_indices) + cumulative_sum = torch.cumsum(sampled_probs, dim=-1) + cumulative_sum = torch.cat( + (torch.zeros(cumulative_sum.shape[0], 1, device=cumulative_sum.device), cumulative_sum[:, :-1]), dim=-1) + sampled_probs = sampled_probs / (1 - cumulative_sum) + sampled_probs[torch.isinf(sampled_probs)] = -1 + sampled_probs[torch.isnan(sampled_probs)] = -1 + sampled_probs = torch.clamp(sampled_probs, min=0.0, max=1.0) + sampled_indices = sampled_indices.view(bs, seq_len, -1) + sampled_probs = sampled_probs.view(bs, seq_len, -1) + probabilities = probabilities.view(bs, seq_len, -1) + + return sampled_indices, sampled_probs, probabilities + + @torch.no_grad() + def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_length=4, use_cache=True, + attention_mask=None, len_posi=None, ): + """ + Generates sequences using top-k sampling. + + Parameters: + - hidden_states: The input hidden states. + - input_ids: The input token IDs. + - head: The head layer for processing the hidden states. + - logits_processor: A function to process the logits. + - max_length: The maximum length of the generated sequence. Default is 4. + - use_cache: A boolean indicating whether to use caching. Default is True. + - attention_mask: An optional attention mask. + - len_posi: The length of the position IDs. + + Returns: + - The generated token indices, probabilities, and optionally the operation. + """ + top_k = 5 + bs = input_ids.shape[0] + input_ids = input_ids[:, 1:] + input_ids = input_ids.to(hidden_states.device) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.to(self.device) + zero_num = position_ids.shape[1] - position_ids.max(dim=-1).values - 1 + zero_num = zero_num[:, None] + ss_token, ss_prob, ss_op = [], [], [] + if len_posi is None: + len_posi = input_ids.shape[1] + self.reset() + if use_cache: + if hasattr(self, "stable_kv") and self.stable_kv is not None: + kv_len = self.stable_kv[0][0].shape[2] + position_ids = position_ids[:, kv_len:] + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, past_key_values=self.stable_kv, + use_cache=True, attention_mask=attention_mask, + position_ids=position_ids) + else: + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True, + attention_mask=attention_mask, position_ids=position_ids) + self.stable_kv = past_key_values + last_nopadding = position_ids.argmax(dim=-1) + ab = tuple(range(bs)) + last_hidden = out_hidden[ab, last_nopadding][:, None] + if not self.diff_device: + last_headout = head(last_hidden) + else: + if hasattr(self, "layer_device"): + last_headout = head(last_hidden) + last_headout = last_headout.to(self.layer_device) + else: + last_headout = F.linear(last_hidden, self.headweight) + + for i in range(len(self.tree_buffer['tree_indices'])): + if logits_processor is not None: + topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) + else: + topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, + top_k, + dim=-1).values + op = None + + ss_token.append(topk_index) + ss_prob.append(topk_prob) + ss_op.append(op) + + input_ids = topk_index.view(bs, -1)[:, self.tree_buffer['tree_indices'][i]] + + attention_mask = torch.cat((attention_mask, torch.ones_like(input_ids, device=attention_mask.device, + dtype=attention_mask.dtype)), dim=1) + + if i == 0: + hidden_states = last_hidden + else: + hidden_states = out_hidden + hidden_states = self.repeat_hidden(hidden_states, self.tree_buffer["repeat_nums"][i]) + self.tree_mask = self.tree_buffer['attn_mask'][i] + position_ids = len_posi + self.tree_buffer["position_ids"][i][None, :] - zero_num + out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, past_key_values=past_key_values, + position_ids=position_ids, use_cache=True, + attention_mask=attention_mask) + len_posi += 1 + + if not self.diff_device: + last_headout = head(out_hidden) + else: + if hasattr(self, "layer_device"): + last_headout = head(out_hidden) + last_headout = last_headout.to(self.layer_device) + else: + last_headout = F.linear(out_hidden[0], self.headweight) + + if logits_processor is not None: + topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) + else: + topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, top_k, + dim=-1).values + op = None + ss_token.append(topk_index) + ss_prob.append(topk_prob) + ss_op.append(op) + + else: + # TODO + pass + + return (torch.cat(ss_token, dim=1), torch.cat(ss_prob, dim=1), ss_op) + +def prepare_logits_processor( + temperature=0.0, repetition_penalty=0.0, top_p=0.0, top_k=0 +) -> LogitsProcessorList: + """ + Prepares a list of logits processors based on the provided parameters. + + Parameters: + - temperature: The temperature to apply to the logits before applying softmax. + - repetition_penalty: The penalty for repeating tokens. + - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. + - top_k: The number of highest probability vocabulary tokens to keep for top-k filtering. + + Returns: + - A LogitsProcessorList containing the specified logits processors. + """ + processor_list = LogitsProcessorList() + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +def pad_path(path, length, pad_value=-2): + """ + Pad the given path list with a specific value up to a specified length. + + Parameters: + - path (list): The original list that needs padding. + - length (int): The desired length of the padded list. + - pad_value (optional, default=-2): The value to use for padding. + + Returns: + - list: A new list based on the original path but padded to the desired length. + + Example: + >>> pad_path([1,2,3], 5) + [1, 2, 3, -2, -2] + + Note: + If the given path is already longer than the specified length, + then no padding occurs, and the original path is returned. + """ + + # Calculate the number of padding values needed by subtracting the length + # of the path from the desired length. + # Append the padding values to the original path and return the new list. + return path + [pad_value] * (length - len(path)) + + +def generate_tree_buffers(tree_choices, device="cuda"): + """ + Generates tree buffers for a given set of tree choices. + + Parameters: + - tree_choices (list): A list of tree choices. + - device (str, optional): The device to use for tensor operations. Default is "cuda". + + Returns: + - A dictionary containing the generated tree buffers. + """ + TOPK = 5 + sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x)) + tree_len = len(sorted_tree_choices) + 1 + + # Initialize depth_counts to keep track of how many choices have a particular depth + depth_counts = [] + prev_depth = 0 + for path in sorted_tree_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + + tree_attn_mask = torch.eye(tree_len, tree_len) + tree_attn_mask[:, 0] = 1 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + # retrieve ancestor position + if len(cur_tree_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_tree_choice) - 1): + ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + tree_attn_mask[j + start + 1, ancestor_idx] = 1 + start += depth_counts[i] + + tree_indices = torch.zeros(tree_len, dtype=torch.long) + p_indices = [0 for _ in range(tree_len - 1)] + b_indices = [[] for _ in range(tree_len - 1)] + tree_indices[0] = 0 + start = 0 + bias = 0 + for i in range(len(depth_counts)): + inlayer_bias = 0 + b = [] + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + cur_parent = cur_tree_choice[:-1] + if j != 0: + if cur_parent != parent: + bias += 1 + inlayer_bias += 1 + parent = cur_parent + b = [] + else: + parent = cur_parent + tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1 + p_indices[start + j] = inlayer_bias + if len(b) > 0: + b_indices[start + j] = copy.deepcopy(b) + else: + b_indices[start + j] = [] + b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1) + start += depth_counts[i] + + p_indices = [-1] + p_indices + tree_position_ids = torch.zeros(tree_len, dtype=torch.long) + start = 0 + for i in range(len(depth_counts)): + tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 + start += depth_counts[i] + + retrieve_indices_nest = [] + retrieve_paths = [] + for i in range(len(sorted_tree_choices)): + cur_tree_choice = sorted_tree_choices[-i - 1] + retrieve_indice = [] + if cur_tree_choice in retrieve_paths: + continue + else: + for c in range(len(cur_tree_choice)): + retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1])) + retrieve_paths.append(cur_tree_choice[:c + 1]) + retrieve_indices_nest.append(retrieve_indice) + max_length = max([len(x) for x in retrieve_indices_nest]) + retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest] + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + retrieve_indices = retrieve_indices + 1 + retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], + dim=1) + + maxitem = retrieve_indices.max().item() + 5 + + def custom_sort(lst): + # sort_keys=[len(list)] + sort_keys = [] + for i in range(len(lst)): + sort_keys.append(lst[i] if lst[i] >= 0 else maxitem) + return sort_keys + + retrieve_indices = retrieve_indices.tolist() + retrieve_indices = sorted(retrieve_indices, key=custom_sort) + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + + p_indices = torch.tensor(p_indices) + p_indices_new = p_indices[retrieve_indices] + p_indices_new = p_indices_new.tolist() + + b_indices = [[]] + b_indices + b_indices_new = [] + for ib in range(retrieve_indices.shape[0]): + iblist = [] + for jb in range(retrieve_indices.shape[1]): + index = retrieve_indices[ib, jb] + if index == -1: + iblist.append([]) + else: + b = b_indices[index] + if len(b) > 0: + bt = [] + for bi in b: + bt.append(torch.where(tree_indices == bi)[0].item()) + iblist.append(torch.tensor(bt, device=device)) + else: + iblist.append(b) + b_indices_new.append(iblist) + + # Aggregate the generated buffers into a dictionary + tree_buffers = { + "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0), + "tree_indices": tree_indices, + "tree_position_ids": tree_position_ids, + "retrieve_indices": retrieve_indices, + } + + # Move the tensors in the dictionary to the specified device + tree_buffers = { + k: v.clone().to(device) + if isinstance(v, torch.Tensor) + else torch.tensor(v, device=device) + for k, v in tree_buffers.items() + } + tree_buffers["p_indices"] = p_indices_new + tree_buffers["b_indices"] = b_indices_new + return tree_buffers + +def _prepare_decoder_attention_mask( + attention_mask, tree_mask, input_shape, inputs_embeds, past_key_values_length +): + """ + Prepares the decoder attention mask by combining causal masks and attention masks. + + Parameters: + - attention_mask (torch.Tensor): The attention mask tensor. + - tree_mask (torch.Tensor): The tree mask tensor. + - input_shape (tuple): The shape of the input tensor. + - inputs_embeds (torch.Tensor): The input embeddings tensor. + - past_key_values_length (int): The length of past key values. + + Returns: + - combined_attention_mask (torch.Tensor): The combined attention mask tensor. + """ + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + torch.float32, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + if tree_mask is not None: + tree_len = tree_mask.size(-1) + bs = combined_attention_mask.size(0) + combined_attention_mask[:, :, -tree_len:, -tree_len:][ + tree_mask.repeat(bs, 1, 1, 1) == 0 + ] = combined_attention_mask.min() + + return combined_attention_mask + +@torch.no_grad() +def forward_with_tree_mask( + model, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + tree_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, # [MODIFIED] past_key_value is KVCache class + inputs_embeds: Optional[torch.FloatTensor] = None, +): + """ + Forward pass through the model with tree mask applied. + + Parameters: + - model (torch.nn.Module): The model to forward through. + - input_ids (torch.LongTensor): The input IDs tensor. + - attention_mask (torch.Tensor, optional): The attention mask tensor. + - tree_mask (torch.Tensor, optional): The tree mask tensor. + - position_ids (torch.LongTensor, optional): The position IDs tensor. + - past_key_values (KVCache, optional): The past key values. + - inputs_embeds (torch.FloatTensor, optional): The input embeddings tensor. + + Returns: + - hidden_states (torch.Tensor): The output hidden states tensor. + - next_cache (Cache): The next cache. + """ + output_attentions = False + use_cache = True + batch_size, seq_length = input_ids.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + inputs_embeds = model.embed_tokens(input_ids) + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = _prepare_decoder_attention_mask( + attention_mask, + tree_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + next_decoder_cache = () + + for idx, decoder_layer in enumerate(model.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + hidden_states = model.norm(hidden_states) + + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + return hidden_states, next_cache + +def initialize_tree(input_ids, model, logits_processor, attention_mask=None): + """ + Initializes the tree by forwarding through the model with a tree mask applied. + + Parameters: + - input_ids (torch.LongTensor): The input IDs tensor. + - model (torch.nn.Module): The model to forward through. + - logits_processor (LogitsProcessorList, optional): An instance of LogitsProcessorList used to modify the prediction scores. + - attention_mask (torch.Tensor, optional): The attention mask tensor. + + Returns: + - tree_logits (torch.Tensor): The tree logits tensor. + - logits (torch.Tensor): The logits tensor. + - hidden_states (torch.Tensor): The hidden states tensor. + - token (torch.Tensor): The token tensor. + - past_key_value (KVCache): The past key values. + """ + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + hidden_states, past_key_value = forward_with_tree_mask(model.base_model.model, input_ids=input_ids, + attention_mask=attention_mask, position_ids=position_ids) + logits = model.base_model.lm_head(hidden_states) + + if logits_processor is not None: + sample_logits = logits[:, -1] + sample_logits = logits_processor(None, sample_logits) + probabilities = torch.nn.functional.softmax(sample_logits, dim=-1) + token = torch.multinomial(probabilities, 1) + else: + token = torch.argmax(logits[:, -1], dim=-1) + token = token[:, None] + input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) + + tree_logits = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head, logits_processor, + attention_mask=attention_mask) + + return tree_logits, logits, hidden_states, token, past_key_value + +def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor): + """ + Generates candidates for the tree. + + Parameters: + - tree_logits (torch.Tensor): The tree logits tensor. + - tree_indices (torch.Tensor): The tree indices tensor. + - retrieve_indices (torch.Tensor): The retrieve indices tensor. + - sample_token (torch.Tensor): The sample token tensor. + - logits_processor (LogitsProcessorList, optional): An instance of LogitsProcessorList used to modify the prediction scores. + + Returns: + - cart_candidates (torch.Tensor): The cartesian product candidates tensor. + - cart_candidates_prob (torch.Tensor, optional): The cartesian product candidates probabilities tensor. + - tree_candidates (torch.Tensor): The tree candidates tensor. + """ + bs = sample_token.shape[0] + sample_token = sample_token.to(tree_indices.device) + + candidates_logit = sample_token + candidates_tree_logits = tree_logits[0] + + candidates = torch.cat([candidates_logit, candidates_tree_logits.view(bs, -1)], dim=-1) + + tree_candidates = candidates[:, tree_indices] + + tree_candidates_ext = torch.cat( + [tree_candidates, torch.zeros((bs, 1), dtype=torch.long, device=tree_candidates.device)-1], dim=-1) + + cart_candidates = tree_candidates_ext[:, retrieve_indices] + + if logits_processor is not None: + candidates_tree_prob = tree_logits[1] + candidates_prob = torch.cat( + [torch.ones((bs, 1), device=candidates_tree_prob.device, dtype=torch.float32), + candidates_tree_prob.view(bs, -1)], + dim=-1) + + tree_candidates_prob = candidates_prob[:, tree_indices] + tree_candidates_prob_ext = torch.cat( + [tree_candidates_prob, torch.ones((bs, 1), dtype=torch.float32, device=tree_candidates_prob.device)], + dim=-1) + cart_candidates_prob = tree_candidates_prob_ext[:, retrieve_indices] + else: + cart_candidates_prob = None + + return cart_candidates, cart_candidates_prob, tree_candidates + +def tree_decoding( + model, + tree_candidates, + past_key_values, + tree_position_ids, + input_ids, + retrieve_indices, + attention_mask=None, + tree_mask=None, +): + + zero_num = attention_mask.shape[1]-attention_mask.long().sum(-1) + zero_num = zero_num[:, None] + position_ids = tree_position_ids[None,:] + input_ids.shape[1]-zero_num + + + attention_mask = torch.cat( + (attention_mask, torch.ones_like(tree_candidates, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1) + + hidden_states, past_key_value = forward_with_tree_mask(model.base_model.model, input_ids=tree_candidates,past_key_values=past_key_values, + attention_mask=attention_mask, tree_mask=tree_mask,position_ids=position_ids) + + tree_logits = model.base_model.lm_head(hidden_states) + + + + + logits = tree_logits[:, retrieve_indices] + return logits, hidden_states,past_key_value + + +def evaluate_posterior( + logits, candidates, logits_processor, cart_candidates_prob, op, p_indices, tree_candidates, b_indices, + finish_flag +): + """ + Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. + + Depending on the temperature value, the function either uses greedy decoding or evaluates posterior + probabilities to select the best candidate. + + Args: + - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). + - candidates (torch.Tensor): Candidate token sequences. + - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. + - posterior_threshold (float): Threshold for posterior probability. + - posterior_alpha (float): Scaling factor for the threshold. + + Returns: + - best_candidate (torch.Tensor): Index of the chosen best candidate. + - accept_length (int): Length of the accepted candidate sequence. + """ + # Greedy decoding based on temperature value + if logits_processor is None: + bs = tree_candidates.size(0) + # Find the tokens that match the maximum logits for each position in the sequence + posterior_mask = ( + candidates[:, :, 1:].to(logits.device) == torch.argmax(logits[:, :, :-1], dim=-1) + ).int() + candidates_accept_length = (torch.cumprod(posterior_mask, dim=-1)).sum(dim=-1) + accept_length = candidates_accept_length.max(dim=1).values + best_candidate = torch.argmax(candidates_accept_length, dim=-1).to(torch.long) + + + bt = tuple(range(bs)) + logits_batch = logits[bt, best_candidate, accept_length, :] + accept_length = accept_length.tolist() + + for batch in range(bs): + if finish_flag[batch]: + accept_length[batch] = 0 + + return best_candidate.tolist(), accept_length, logits_batch + + else: + cart_candidates_prob = cart_candidates_prob.to(logits.device) + bs = cart_candidates_prob.size(0) + + logits = logits_processor(None, logits) + probs = torch.softmax(logits, dim=-1) + + best_candidate_list = [] + accept_length_list = [] + sample_p_list = [] + + for batch in range(bs): + accept_length = 1 + accept_cand = candidates[batch, 0, :1] + best_candidate = 0 + for i in range(1, candidates.shape[2]): + if i != accept_length: + break + adjustflag = False + is_eq = (candidates[batch, :, :accept_length] == accept_cand).all(dim=1) + fi = torch.nonzero(is_eq, as_tuple=True)[0][0] + gtp = probs[batch, fi, i - 1] + candidates_set = [] + for j in range(candidates.shape[1]): + if is_eq[j]: + x = candidates[batch, j, i] + xi = x.item() + if xi in candidates_set or xi == -1: + continue + candidates_set.append(xi) + r = random.random() + px = gtp[xi] + qx = cart_candidates_prob[batch, j, i] + if qx <= 0: + continue + acp = px / qx + if r <= acp: + accept_cand = torch.cat((accept_cand, x[None]), dim=0) + accept_length += 1 + best_candidate = j + break + else: + q = op[i - 1][batch][p_indices[j][i]].clone() + b = b_indices[j][i] + if len(b) > 0: + mask = tree_candidates[batch][b] + q[mask] = 0 + q = q / q.sum() + gtp = gtp - q + gtp[gtp < 0] = 0 + gtp = gtp / gtp.sum() + adjustflag = True + if adjustflag and accept_length != candidates.shape[1]: + sample_p = gtp + else: + sample_p = probs[batch, best_candidate, accept_length - 1] + best_candidate_list.append(best_candidate) + accept_length_list.append(accept_length - 1) + sample_p_list.append(sample_p) + + for batch in range(bs): + if finish_flag[batch]: + accept_length_list[batch] = 0 + + return best_candidate_list, accept_length_list, sample_p_list + +@torch.no_grad() +def update_inference_inputs( + input_ids, + attention_mask, + candidates, + best_candidate, + accept_length, + retrieve_indices, + logits_processor, + new_token, + past_key_values, + model, + hidden_state_new, + sample_p, + finish_flag +): + """ + Updates inference inputs based on the best candidate selected. + + Parameters: + - input_ids (torch.Tensor): The input IDs tensor. + - attention_mask (torch.Tensor): The attention mask tensor. + - candidates (torch.Tensor): The candidates tensor. + - best_candidate (torch.Tensor): The best candidate tensor. + - accept_length (list): The list of accepted lengths. + - retrieve_indices (torch.Tensor): The retrieve indices tensor. + - logits_processor (LogitsProcessorList, optional): An instance of LogitsProcessorList used to modify the prediction scores. + - new_token (torch.Tensor): The new token tensor. + - past_key_values (KVCache): The past key values. + - model (torch.nn.Module): The model to forward through. + - hidden_state_new (torch.Tensor): The new hidden state tensor. + - sample_p (torch.Tensor): The sample probability tensor. + - finish_flag (list): The list indicating whether each batch has finished. + + Returns: + - input_ids (torch.Tensor): The updated input IDs tensor. + - tree_logits (torch.Tensor): The tree logits tensor. + - new_token (torch.Tensor): The updated new token tensor. + - None: Placeholder for a return value. + - token (torch.Tensor): The token tensor. + - attention_mask (torch.Tensor): The updated attention mask tensor. + - finish_flag (list): The updated finish flag list. + - new_outs (list): The list of new outputs. + - new_kv (tuple): The updated past key values. + """ + new_outs = [] + finish_flag = copy.deepcopy(finish_flag) + bs = len(best_candidate) + prev_input_len = input_ids.shape[1] + max_acccept_len = max(accept_length) + + retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices[0]] + + ab = tuple(range(bs)) + select_indices = ( + retrieve_indices.cpu()[ab, best_candidate, : max_acccept_len + 1, ...] + prev_input_len + ) + new_input_ids = candidates[ab, best_candidate, : max_acccept_len + 1, ...] + + draft_hidden = retrieve_hidden_state_new[ab, best_candidate, :max_acccept_len + 1] + + new_attention_mask = torch.zeros((bs, max_acccept_len + 1), dtype=torch.long) + + for batch in range(bs): + new_attention_mask[batch, :accept_length[batch] + 1] = 1 + new_o = new_input_ids[batch, : accept_length[batch] + 1].tolist() + new_outs.append(new_o) + if model.base_model.config.eos_token_id in new_o: + finish_flag[batch] = True + new_token[batch] += accept_length[batch] + 1 + + attention_mask = torch.cat((attention_mask, new_attention_mask.to(attention_mask.device)), dim=1) + + batch_dim_indices = torch.tensor(ab)[:, None].expand(-1, max_acccept_len + 1) + + new_kv = () + + for past_key_values_data in past_key_values: + layer_kv = () + for korv in past_key_values_data: + tgt = korv[batch_dim_indices, :, select_indices, :] + tgt = tgt.permute(0, 2, 1, 3) + dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :] + dst.copy_(tgt, non_blocking=True) + layer_kv += (korv[:, :, : prev_input_len + tgt.shape[-2], :],) + new_kv += (layer_kv,) + + input_ids = torch.cat((input_ids, new_input_ids.to(input_ids.device)), dim=1) + + prob = sample_p + if isinstance(prob, list): + prob = torch.stack(prob) + if logits_processor is not None: + token = torch.multinomial(prob, 1) + else: + token = torch.argmax(prob, dim=-1) + token = token[:, None] + + draft_input_ids = torch.cat((new_input_ids, torch.ones(bs, 1, dtype=torch.long, device=new_input_ids.device)), dim=1) + token_ = token[:, 0] + + draft_input_ids[ab, torch.tensor(accept_length, dtype=torch.long) + 1] = token_ + + tree_logits = model.ea_layer.topK_genrate(draft_hidden, + input_ids=draft_input_ids, + head=model.base_model.lm_head, logits_processor=logits_processor, attention_mask=attention_mask, len_posi=input_ids.shape[1]) + + return input_ids, tree_logits, new_token, None, token, attention_mask, finish_flag, new_outs, new_kv + + +class EAGLE: + """ + EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty + + This class integrates a Huggingface causal LLM with the EAGLE model for enhanced text generation capabilities. + + Attributes: + - base_model: The Huggingface causal LLM. + - eagle_path: The path to the EAGLE model. + - use_tree_attn: Whether to use tree attention. Recommended for strong computational power and small batch sizes. + - ea_layer: The EAGLE model layer. + - tree: The tree structure used for attention. + + The following models can be directly accelerated using their corresponding checkpoints. + vicuna-7b-v1.3: yuhuili/EAGLE-Vicuna-7B-v1.3 + vicuna-13b-v1.3: yuhuili/EAGLE-Vicuna-13B-v1.3 + vicuna-33b-v1.3: yuhuili/EAGLE-Vicuna-33B-v1.3 + LLaMA2-Chat-7B: yuhuili/EAGLE-llama2-chat-7B + LLaMA2-Chat-13B: yuhuili/EAGLE-llama2-chat-13B + LLaMA2-Chat-70B: yuhuili/EAGLE-llama2-chat-70B + Mixtral-8x7B-Instruct-v0.1: yuhuili/EAGLE-mixtral-instruct-8x7B + Other models need to be trained independently. + Please refer to https://github.com/SafeAILab/EAGLE for more information. + + """ + def __init__(self, base_model, eagle_path, use_tree_attn=True): + """ + Initializes the EAGLE class with the base model and EAGLE model path. + + Args: + - base_model: Huggingface causal LLM. + - eagle_path: Path of EAGLE. + - use_tree_attn: Whether to use tree attention. Default is True. + """ + self.base_model = base_model + + configpath = os.path.join(eagle_path, "config.json") + if not os.path.exists(configpath): + configpath = hf_hub_download(eagle_path, "config.json") + + load_model_path = os.path.join(eagle_path, "pytorch_model.bin") + if not os.path.exists(load_model_path): + load_model_path = hf_hub_download(eagle_path, "pytorch_model.bin") + + config = EAGLE_Config.from_pretrained(configpath) + with open(configpath, "r") as f: + con = json.loads(f.read()) + try: + bias = con["bias"] + except: + bias = True + + self.ea_layer = EAGLEModel(config, bias=bias) + + ea_layer_state_dict = torch.load(load_model_path, map_location="cpu") + self.ea_layer.load_state_dict(ea_layer_state_dict, strict=True) + device = base_model.model.layers[-1].self_attn.q_proj.weight.device + self.ea_layer.to(self.base_model.dtype).to(device) + self.ea_layer.device = device + + if device != base_model.lm_head.weight.device: + self.ea_layer.diff_device = True + self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device) + else: + self.ea_layer.diff_device = False + + if use_tree_attn: + self.ea_layer.tree = tree_structure + self.tree = tree_structure + else: + self.ea_layer.tree = chain_structure + self.tree = chain_structure + + self.ea_layer.init_tree() + self.base_model.eval() + self.ea_layer.eval() + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + temperature=0.0, + top_p=0.0, + top_k=0, + max_new_tokens=512, + max_length=2048, + ) -> torch.LongTensor: + """ + Generates text using the EAGLE model. + + Args: + - input_ids: The input IDs tensor. + - attention_mask: The attention mask tensor. + - temperature: The temperature for sampling. Default is 0.0. + - top_p: The top-p value for nucleus sampling. Default is 0.0. + - top_k: The top-k value for top-k sampling. Default is 0. + - max_new_tokens: The maximum number of new tokens to generate. Default is 512. + - max_length: The maximum length of the generated text. Default is 2048. + + Returns: + - torch.LongTensor: The generated text as a tensor of token IDs. + """ + tree_choices = self.tree + bs = input_ids.shape[0] + if temperature > 1e-5: + logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) + else: + logits_processor = None + + input_ids = input_ids.clone() + self.ea_layer.reset_kv() + + if hasattr(self, "tree_choices") and self.tree_choices == tree_choices: + tree_buffers = self.tree_buffers + else: + tree_buffers = generate_tree_buffers( + tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device + ) + tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to( + self.base_model.lm_head.weight.device) + tree_buffers["tree_position_ids"] = tree_buffers["tree_position_ids"].to(self.base_model.device) + self.tree_buffers = tree_buffers + self.tree_choices = tree_choices + + tree_buffers["retrieve_indices_batch"] = tree_buffers["retrieve_indices"].expand(bs, -1, -1) + + input_ids=input_ids.to(tree_buffers["tree_position_ids"].device) + attention_mask=attention_mask.to(tree_buffers["tree_position_ids"].device) + + bool_mask = attention_mask.bool() + + out_inputids = [ids.tolist() for ids, mask in zip(input_ids, bool_mask)] + + tree_logits, logits, hidden_state, sample_token,past_key_values = initialize_tree( + input_ids, self, logits_processor, + attention_mask=attention_mask + ) + + new_token = [0]*bs + finish_flag=[False]*bs + + for idx in range(max_length): + candidates, cart_candidates_prob, tree_candidates = generate_candidates( + tree_logits, + tree_buffers["tree_indices"], + tree_buffers["retrieve_indices"], + sample_token, + logits_processor + ) + logits, hidden_state_new,past_key_values = tree_decoding( + self, + tree_candidates, + past_key_values, + tree_buffers["tree_position_ids"], + input_ids, + tree_buffers["retrieve_indices_head"], + attention_mask=attention_mask, + tree_mask=tree_buffers["tree_attn_mask"] + ) + best_candidate, accept_length, sample_p = evaluate_posterior( + logits, candidates, logits_processor, cart_candidates_prob, tree_logits[2], tree_buffers["p_indices"], + tree_candidates, tree_buffers["b_indices"],finish_flag + ) + input_ids, tree_logits, new_token, hidden_state, sample_token, attention_mask, newfinish_flag, new_outs,past_key_values = update_inference_inputs( + input_ids, + attention_mask, + candidates, + best_candidate, + accept_length, + tree_buffers["retrieve_indices_batch"], + logits_processor, + new_token, + past_key_values, + self, + hidden_state_new, + sample_p, + finish_flag + ) + + min_uf_newtokens = max_length + 10 + for batch in range(bs): + if not finish_flag[batch]: + out_inputids[batch].extend(new_outs[batch]) + min_uf_newtokens = min(min_uf_newtokens, new_token[batch]) + finish_flag = newfinish_flag + + if min(finish_flag): + break + if min_uf_newtokens > max_new_tokens: + break + if input_ids.shape[1] + 10 + len(tree_choices) > max_length: + break + + if len(out_inputids)==1: + return out_inputids[0] + return out_inputids From 38083cc1a558c80b23c18dcce1c071ae23a0aca3 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Sun, 14 Apr 2024 23:01:38 -0500 Subject: [PATCH 02/12] add example file Signed-off-by: Siddhi Velankar --- .../text-generation/eagle/eagle_example.py | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 examples/huggingface/pytorch/text-generation/eagle/eagle_example.py diff --git a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py new file mode 100644 index 00000000000..72b4bd31562 --- /dev/null +++ b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py @@ -0,0 +1,133 @@ +# Import necessary libraries +import os +import time +import argparse +from tqdm import tqdm +from pprint import pprint +from modeling_eagle import EAGLE +from transformers import AutoModelForCausalLM,LlamaModel,MistralModel,AutoTokenizer,MixtralForCausalLM,LlamaConfig,LlamaForCausalLM +import torch +from fastchat.model import get_conversation_template + +# Define a dictionary to map string representations of data types to their corresponding PyTorch data types +DTYPE_MAP = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +# Function to parse command-line arguments +def parse_argument(): + # Construct the argument parser and add the arguments + arg_parser = argparse.ArgumentParser(description="Test EAGLE with Llama-2") + arg_parser.add_argument("-d", '--device', type=str, default="cpu", choices=["cpu", "xpu", "cuda"], + help="Target device for text generation") + arg_parser.add_argument("-t", '--dtype', type=str, default="float32", choices=["float32", "float16", "bfloat16"], + help="Data type for text generation") + arg_parser.add_argument('--max_new_tokens', type=int, default=512, + help="Number of max new tokens for text generation.") + arg_parser.add_argument('--use_eagle', action='store_true', help='Use EAGLE model for generation') + return arg_parser.parse_args() + + +if __name__ == "__main__": + # Define paths to the base and eagle models + # Note - You can use your choice of base model, we are using llama-2-7b-chat in this example + base_model_path = "meta-llama/Llama-2-7b-chat-hf" + ea_model_path = "yuhuili/EAGLE-llama2-chat-7B" + + # Parse command-line arguments + args = parse_argument() + + # Extract arguments + device = args.device + dtype = DTYPE_MAP[args.dtype] + max_new_tokens = args.max_new_tokens + use_eagle = args.use_eagle + + if device == "xpu": + import intel_extension_for_pytorch as ipex + + if device == "cpu": + assert dtype == torch.float32, f"CPU can only support float32. Got dtype = {args.dtype}" + + tokenizer=AutoTokenizer.from_pretrained(base_model_path) + + # Define a message, conversation template and system message to be processed by the model + your_message="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions." + conv = get_conversation_template("llama-2-chat") + sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." + conv.system_message = sys_p + conv.append_message(conv.roles[0], your_message) + conv.append_message(conv.roles[1], None) + prompt1 = conv.get_prompt()+" " + + your_message="Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point." + conv = get_conversation_template("llama-2-chat") + sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." + conv.system_message = sys_p + conv.append_message(conv.roles[0], your_message) + conv.append_message(conv.roles[1], None) + prompt2 = conv.get_prompt()+" " + + your_message = "Write a brief technical documentation for a new software feature, explaining its functionality, benefits, and how to implement it." + conv = get_conversation_template("llama-2-chat") + sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." + conv.system_message = sys_p + conv.append_message(conv.roles[0], your_message) + conv.append_message(conv.roles[1], None) + prompt3 = conv.get_prompt()+" " + + your_message = "Craft a marketing campaign plan for a new product launch, including the target audience, key messages, and promotional strategies." + conv = get_conversation_template("llama-2-chat") + sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." + conv.system_message = sys_p + conv.append_message(conv.roles[0], your_message) + conv.append_message(conv.roles[1], None) + prompt4 = conv.get_prompt()+" " + + # Pick the prompt you want to use to run the example or write your own prompt, initialize it as a list + text = [prompt1] + + # Load the model and set parameters + model = AutoModelForCausalLM.from_pretrained(base_model_path, + torch_dtype=dtype, + ).eval().to(device) + # model = torch.xpu.optimize(model) + model.config.pad_token_id = model.config.eos_token_id + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + + eagle = EAGLE(model, ea_model_path) + + # Iterate over each prompt in the text list + for prompt in text: + # Initialize counters + t_total = 0. + total_new_tokens = 0 + for _ in tqdm(range(10)): + t_b = time.time() + + # inputs = tokenizer(text, return_tensors="pt", padding=True).to(device) + inputs = tokenizer(text, return_tensors="pt").to(device) + + # Generate text using the EAGLE model if the flag is set, otherwise use the base model + if use_eagle: + outs = eagle.generate(**inputs, max_new_tokens=max_new_tokens,temperature=0.5) + else: + outs=model.generate(**inputs, max_new_tokens=max_new_tokens,temperature=0.5) + output=tokenizer.batch_decode(outs) + t_e = time.time() + + + # Update counters + t_total += t_e - t_b + total_new_tokens += len(outs) - inputs.input_ids.shape[-1] + + # Print total new tokens and generated text + print("Total new tokens", total_new_tokens) + pprint(output) + print(f"TPS: {total_new_tokens / t_total}") + + del model + From 1cdbf5bb32e486fbd6cd106cedc38a186b202d8e Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Sun, 14 Apr 2024 23:38:12 -0500 Subject: [PATCH 03/12] Add README Signed-off-by: Siddhi Velankar --- .../pytorch/text-generation/eagle/README.md | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 examples/huggingface/pytorch/text-generation/eagle/README.md diff --git a/examples/huggingface/pytorch/text-generation/eagle/README.md b/examples/huggingface/pytorch/text-generation/eagle/README.md new file mode 100644 index 00000000000..c5db3d15d91 --- /dev/null +++ b/examples/huggingface/pytorch/text-generation/eagle/README.md @@ -0,0 +1,47 @@ +# Eagle - Speculative Sampling + +Intel Extension for Transformers supports the EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) which is a speculative sampling method that improves text generation speed. + +See [here](https://arxiv.org/abs/2401.15077) to view the paper and [here](https://github.com/SafeAILab/EAGLE) for more info on EAGLE code . + + +## Setup and installation + +With pip: + +```bash +pip install eagle-llm +``` + +From the source: + +```bash +git clone https://github.com/SafeAILab/EAGLE.git +cd EAGLE +pip install -e . +``` +The modeling_eagle.py script can be found at intel_extension_for_transformers/transformers/modeling/modeling_eagle.py + +## Usage + +The script accepts several command-line arguments: + +-d or --device: Target device for text generation (default: "cpu", options: "cpu", "xpu", "cuda"). +-t or --dtype: Data type for text generation (default: "float32", options: "float32", "float16", "bfloat16"). +--max_new_tokens: Number of max new tokens for text generation (default: 512). +--use_eagle: Use EAGLE model for generation (default: False). + +```bash +python3 eagle_example.py +python main.py -d xpu --max_new_tokens 1024 --use_eagle + +``` + +The default base model is set to "meta-llama/Llama-2-7b-chat-hf", you can change it in the script by reassigning the variable "base_model_path" to model of your choice + + +# Features +- Dynamic Device and Data Type Configuration: The script allows specifying the target device and data type for text generation, supporting CPU, XPU, and CUDA devices, as well as different data types. +- EAGLE Model Integration: The script demonstrates how to integrate the EAGLE model for enhanced text generation capabilities. +- Performance Measurement: The script measures the performance of the text generation process, including the tokens per second (TPS) based on the total new tokens and total time. + From 5722da0ed79197d40ed0b1753160382b0fab7e2c Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Sun, 14 Apr 2024 23:39:03 -0500 Subject: [PATCH 04/12] Update README Signed-off-by: Siddhi Velankar --- .../pytorch/text-generation/eagle/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/README.md b/examples/huggingface/pytorch/text-generation/eagle/README.md index c5db3d15d91..3d716e96a4d 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/README.md +++ b/examples/huggingface/pytorch/text-generation/eagle/README.md @@ -2,7 +2,7 @@ Intel Extension for Transformers supports the EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) which is a speculative sampling method that improves text generation speed. -See [here](https://arxiv.org/abs/2401.15077) to view the paper and [here](https://github.com/SafeAILab/EAGLE) for more info on EAGLE code . +See [here](https://arxiv.org/abs/2401.15077) to view the paper and [here](https://github.com/SafeAILab/EAGLE) for more info on EAGLE code. ## Setup and installation @@ -26,10 +26,10 @@ The modeling_eagle.py script can be found at intel_extension_for_transformers/tr The script accepts several command-line arguments: --d or --device: Target device for text generation (default: "cpu", options: "cpu", "xpu", "cuda"). --t or --dtype: Data type for text generation (default: "float32", options: "float32", "float16", "bfloat16"). ---max_new_tokens: Number of max new tokens for text generation (default: 512). ---use_eagle: Use EAGLE model for generation (default: False). +- -d or --device: Target device for text generation (default: "cpu", options: "cpu", "xpu", "cuda"). +- -t or --dtype: Data type for text generation (default: "float32", options: "float32", "float16", "bfloat16"). +- --max_new_tokens: Number of max new tokens for text generation (default: 512). +- --use_eagle: Use EAGLE model for generation (default: False). ```bash python3 eagle_example.py From c04831a072d7686bdb0274746a573da7f02dd386 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Mon, 22 Apr 2024 23:58:53 -0500 Subject: [PATCH 05/12] update import statement Signed-off-by: Siddhi Velankar --- .../huggingface/pytorch/text-generation/eagle/eagle_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py index 72b4bd31562..1d1a4890af6 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py +++ b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py @@ -4,7 +4,7 @@ import argparse from tqdm import tqdm from pprint import pprint -from modeling_eagle import EAGLE +from eagle.modeling_eagle import EAGLE from transformers import AutoModelForCausalLM,LlamaModel,MistralModel,AutoTokenizer,MixtralForCausalLM,LlamaConfig,LlamaForCausalLM import torch from fastchat.model import get_conversation_template From 27360c2a0841174733a093dda44dab0e32536a4a Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 23 Apr 2024 00:02:59 -0500 Subject: [PATCH 06/12] tps logic update based on outs type Signed-off-by: Siddhi Velankar --- .../pytorch/text-generation/eagle/eagle_example.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py index 1d1a4890af6..c41434ad12a 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py +++ b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py @@ -119,13 +119,15 @@ def parse_argument(): output=tokenizer.batch_decode(outs) t_e = time.time() - # Update counters t_total += t_e - t_b - total_new_tokens += len(outs) - inputs.input_ids.shape[-1] - - # Print total new tokens and generated text - print("Total new tokens", total_new_tokens) + + # Update tokens per second based on outs type + if isinstance(outs, list): + total_new_tokens = len(outs) + else: + total_new_tokens += len(outs[0]) - inputs.input_ids.shape[-1] + # Print output pprint(output) print(f"TPS: {total_new_tokens / t_total}") From 4d4516b12f8c271825c6622a689a1f22f6e52104 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 23 Apr 2024 00:03:34 -0500 Subject: [PATCH 07/12] Delete intel_extension_for_transformers/transformers/modeling/modeling_eagle.py Signed-off-by: Siddhi Velankar --- .../transformers/modeling/modeling_eagle.py | 2120 ----------------- 1 file changed, 2120 deletions(-) delete mode 100644 intel_extension_for_transformers/transformers/modeling/modeling_eagle.py diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_eagle.py b/intel_extension_for_transformers/transformers/modeling/modeling_eagle.py deleted file mode 100644 index 3cd486bdbac..00000000000 --- a/intel_extension_for_transformers/transformers/modeling/modeling_eagle.py +++ /dev/null @@ -1,2120 +0,0 @@ -import os -import random -import copy -import json -import torch -from torch import nn -from transformers.configuration_utils import PretrainedConfig -import math -from typing import List, Optional, Tuple, Union -import torch.nn.functional as F -import torch.utils.checkpoint -from transformers.activations import ACT2FN -from huggingface_hub import hf_hub_download -from transformers.generation.logits_process import ( - LogitsProcessorList, - RepetitionPenaltyLogitsProcessor, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, -) -from transformers.cache_utils import Cache, DynamicCache - -# Define the tree and chain structures for the model. -tree_structure = [[0], [1], [2], [3], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [2, 1], [3, 0] - , [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1], [1, 0, 0], - [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]] - -chain_structure = [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0]] - -# The model structure of EAGLE is largely based on a single Decoder layer from LLaMA, with the model definition essentially copied from LLaMA. - -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - This function generates a mask that prevents future tokens from attending to earlier tokens in the sequence. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - This function is used to adjust the mask dimensions to match the expected input shape for attention mechanisms. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Repeats the key-value pairs in the hidden states tensor. - This function is used to adjust the dimensions of the hidden states tensor to match the expected input shape for attention mechanisms. - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def rotate_half(x): - """ - Rotates half the hidden dimensions of the input tensor. - This function is used to apply rotary positional embeddings, which help in modeling the relative positions of tokens in the sequence. - """ - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - """ - Applies rotary positional embeddings to the query and key tensors. - This function is used to incorporate positional information into the attention mechanism, enhancing the model's ability to understand the order of tokens in the sequence. - """ - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - -class EAGLERotaryEmbedding(torch.nn.Module): - """ - EAGLERotaryEmbedding class implements rotary embeddings for positional encoding in transformer models. - It calculates and caches the cosine and sine values for the embeddings to speed up computation. - """ - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - """ - Initializes the EAGLERotaryEmbedding with the given dimensions, maximum position embeddings, base, and device. - """ - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Pre-calculate and cache cosine and sine values for the embeddings. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - """ - Sets up the cosine and sine caches for the given sequence length, device, and data type. - This method is called during initialization and when the sequence length exceeds the cached length. - """ - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - """ - Forward pass of the EAGLERotaryEmbedding. - If the sequence length exceeds the cached length, it updates the cosine and sine caches. - Returns the cosine and sine values for the embeddings. - """ - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class EAGLELinearScalingRotaryEmbedding(EAGLERotaryEmbedding): - """ - EAGLELinearScalingRotaryEmbedding extends EAGLERotaryEmbedding with linear scaling. - It adjusts the sequence length by a scaling factor before calculating the cosine and sine values. - LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev - """ - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - """ - Initializes the EAGLELinearScalingRotaryEmbedding with the given dimensions, maximum position embeddings, base, device, and scaling factor. - """ - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - """ - Overrides the _set_cos_sin_cache method to apply linear scaling to the sequence length before calculating the cosine and sine values. - """ - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class EAGLEDynamicNTKScalingRotaryEmbedding(EAGLERotaryEmbedding): - """ - EAGLEDynamicNTKScalingRotaryEmbedding extends EAGLERotaryEmbedding with dynamic NTK scaling. - It adjusts the base and sequence length dynamically based on the scaling factor and sequence length. - """ - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - """ - Initializes the EAGLEDynamicNTKScalingRotaryEmbedding with the given dimensions, maximum position embeddings, base, device, and scaling factor. - """ - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - """ - Overrides the _set_cos_sin_cache method to apply dynamic NTK scaling to the sequence length before calculating the cosine and sine values. - """ - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) -class EAGLE_Config(PretrainedConfig): - r""" - Copyed from LlamaConfig, the structure of EAGLE consists of a single LlamaDecoder layer. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - """ - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - """ - Initializes the model with the given configuration parameters. - - Parameters: - - vocab_size: The size of the vocabulary. Default is 32000. - - hidden_size: The size of the hidden layers in the model. Default is 4096. - - intermediate_size: The size of the intermediate layers in the model. Default is 11008. - - num_hidden_layers: The number of hidden layers in the model. Default is 32. - - num_attention_heads: The number of attention heads in the model. Default is 32. - - num_key_value_heads: The number of key-value heads in the model. Default is None, which sets it to the same as num_attention_heads. - - hidden_act: The activation function to use in the hidden layers. Default is "silu". - - max_position_embeddings: The maximum number of position embeddings. Default is 2048. - - initializer_range: The range for the initializer. Default is 0.02. - - rms_norm_eps: The epsilon value for RMS normalization. Default is 1e-6. - - use_cache: Whether to use cache for the model. Default is True. - - pad_token_id: The ID of the padding token. Default is None. - - bos_token_id: The ID of the beginning-of-sentence token. Default is 1. - - eos_token_id: The ID of the end-of-sentence token. Default is 2. - - pretraining_tp: The pretraining type. Default is 1. - - tie_word_embeddings: Whether to tie the word embeddings. Default is False. - - rope_scaling: The configuration for ROPE scaling. Default is None. - - **kwargs: Additional keyword arguments. - """ - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # For backward compatibility, set num_key_value_heads to num_attention_heads if not provided. - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validates the `rope_scaling` configuration. - - Raises: - - ValueError: If `rope_scaling` is not None and does not meet the expected format or values. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") - -class EAGLEAttention(nn.Module): - """ - EAGLEAttention implements multi-headed attention mechanism inspired by the 'Attention Is All You Need' paper. - It includes support for rotary positional embeddings (RoPE) with linear or dynamic scaling. - """ - - def __init__(self, config): - """ - Initializes the EAGLEAttention module with the given configuration. - - Parameters: - - config: A configuration object containing model parameters. - """ - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - """ - Initializes the rotary positional embeddings (RoPE) based on the configuration. - """ - if self.config.rope_scaling is None: - self.rotary_emb = EAGLERotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = EAGLELinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = EAGLEDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - """ - Reshapes the input tensor to the required shape for multi-head attention. - - Parameters: - - tensor: The input tensor. - - seq_len: The sequence length. - - bsz: The batch size. - - Returns: - - The reshaped tensor. - """ - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass of the EAGLEAttention module. - - Parameters: - - hidden_states: The input hidden states. - - attention_mask: The attention mask tensor. - - position_ids: The position IDs tensor. - - past_key_value: The past key-value states. - - output_attentions: Whether to output attention weights. - - use_cache: Whether to use cache for the model. - - Returns: - - The output tensor. - - The attention weights (if output_attentions is True). - - The past key-value states (if use_cache is True). - """ - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - # Splitting the projection weights for pretraining - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # Reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # Repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class EAGLEMLP(nn.Module): - """ - EAGLEMLP implements the MLP (Multi-Layer Perceptron) component of the EAGLE model. - It includes support for pretraining with multiple slices. - """ - - def __init__(self, config): - """ - Initializes the EAGLEMLP module with the given configuration. - - Parameters: - - config: A configuration object containing model parameters. - """ - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - -class EAGLERMSNorm(nn.Module): - """ - EAGLERMSNorm implements Root Mean Square Layer Normalization (RMSNorm) for stabilizing the training of deep neural networks. - It normalizes the input tensor by scaling it with a learned weight and adjusting its variance. - """ - - def __init__(self, hidden_size, eps=1e-6): - """ - Initializes the EAGLERMSNorm module with the given hidden size and epsilon for numerical stability. - - Parameters: - - hidden_size: The size of the hidden layer. - - eps: A small value added to the variance for numerical stability. Default is 1e-6. - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - """ - Forward pass of the EAGLERMSNorm module. - - Parameters: - - hidden_states: The input tensor to be normalized. - - Returns: - - The normalized tensor. - """ - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class EAGLEDecoderLayer(nn.Module): - """ - EAGLEDecoderLayer represents a single layer of the EAGLE decoder, which includes self-attention and a multi-layer perceptron (MLP). - It also optionally includes layer normalization before the self-attention mechanism. - """ - - def __init__(self, config, index): - """ - Initializes the EAGLEDecoderLayer with the given configuration and index. - - Parameters: - - config: A configuration object containing model parameters. - - index: The index of the layer within the decoder. - """ - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = EAGLEAttention(config=config) - self.mlp = EAGLEMLP(config) - self.index = index - if self.index != 0: - self.input_layernorm = EAGLERMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = EAGLERMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - - Forward pass of the EAGLEDecoderLayer. - - Parameters: - - hidden_states: The input tensor of shape `(batch, seq_len, embed_dim)`. - - attention_mask: An optional tensor of shape `(batch, 1, tgt_len, src_len)` indicating padding elements. - - position_ids: An optional tensor of shape `(batch, seq_len)` indicating the position of each token in the sequence. - - past_key_value: An optional tuple of tensors containing past key and value states for caching. - - output_attentions: A boolean indicating whether to return the attention weights. - - use_cache: A boolean indicating whether to use caching for faster decoding. - - Returns: - - The output tensor of shape `(batch, seq_len, embed_dim)`. - - Optionally, the attention weights and past key-value states. - """ - residual = hidden_states - - if self.index != 0: - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class node: - """ - Node represents a node in a tree structure, with support for tracking parent-child relationships and depth. - It also allows for storing a value and a dictionary key associated with the node. - """ - - def __init__(self, parent=None, value=None, dict_key=None): - """ - Initializes a Node with an optional parent, value, and dictionary key. - - Parameters: - - parent: The parent node of this node. Default is None. - - value: The value associated with this node. Default is None. - - dict_key: The dictionary key associated with this node. Default is None. - """ - self.parent = parent - self.value = value - if parent: - self.depth = parent.depth + 1 - parent.children.append(self) - else: - self.depth = 0 - self.children = [] - self.dict_key = dict_key - - def is_leaf(self): - """ - Checks if the node is a leaf node (i.e., has no children). - - Returns: - - True if the node is a leaf, False otherwise. - """ - return len(self.children) == 0 - - def all_index(self): - """ - Returns the index path from the root to this node. - - Returns: - - A list of indices representing the path from the root to this node. - """ - if not self.parent.parent: - return [self.index] - else: - return self.parent.all_index() + [self.index] - -class Tree: - """ - Tree represents a tree structure, with support for tracking parent-child relationships, depth, and indexing. - It is initialized with a list of tree nodes, each represented as a list of values. - """ - - def __init__(self, tree_list): - """ - Initializes the Tree with the given list of tree nodes. - - Parameters: - - tree_list: A list of tree nodes, each represented as a list of values. - """ - sorted_tree_list = sorted(tree_list, key=lambda x: (len(x), x)) - self.root = node() - self.node_dic = {} - for tree_node in sorted_tree_list: - cur_value = tree_node[-1] - if len(tree_node) == 1: - cur_node = node(parent=self.root, value=cur_value, dict_key=tuple(tree_node)) - else: - cur_parent = self.node_dic[tuple(tree_node[:-1])] - cur_node = node(parent=cur_parent, value=cur_value, dict_key=tuple(tree_node)) - self.node_dic[tuple(tree_node)] = cur_node - self.indexnode() - - def max_depth(self): - """ - Returns the maximum depth of the tree. - - Returns: - - The maximum depth of the tree. - """ - return max([item.depth for item in self.node_dic.values()]) - - def num_node_wchild(self): - """ - Returns the number of nodes in the tree that have children. - - Returns: - - The number of nodes with children. - """ - num_c = 0 - for item in self.node_dic.values(): - if not item.is_leaf(): - num_c += 1 - return num_c - - def get_node_wchild(self): - """ - Returns a list of nodes in the tree that have children. - - Returns: - - A list of nodes with children. - """ - ns = [] - for item in self.node_dic.values(): - if not item.is_leaf(): - ns.append(item) - return ns - - def indexnode(self): - """ - Assigns an index to each non-leaf node in the tree. - """ - cur_index = 0 - for key in self.node_dic: - cur_node = self.node_dic[key] - if not cur_node.is_leaf(): - cur_node.index = cur_index - cur_index += 1 - - -def generate_tree_buffers_for_eagle(tree_choices, device="cuda"): - """ - Generates tree buffers for the EAGLE model based on the given tree choices. - - Parameters: - - tree_choices: A list of tree nodes, each represented as a list of values. - - device: The device to which the tensors should be moved. Default is "cuda". - - Returns: - - A dictionary containing the tree buffers. - """ - TOPK = 5 - tree = Tree(tree_choices) - tree_len = tree.num_node_wchild() - - max_depth = tree.max_depth() - nodes_wc = tree.get_node_wchild() - - depth_counts = [0 for _ in range(max_depth - 1)] - for x in nodes_wc: - depth_counts[x.depth - 1] += 1 - depth_counts_sum = [sum(depth_counts[:i + 1]) for i in range(len(depth_counts))] - - tree_attn_mask = torch.eye(tree_len, tree_len) - - for id, x in enumerate(nodes_wc): - tree_attn_mask[id, x.all_index()] = 1 - - tree_attn_mask_list0 = [tree_attn_mask[:ml, :ml] for ml in depth_counts_sum] - tree_attn_mask_list = [] - for id, x in enumerate(tree_attn_mask_list0): - x = x[-depth_counts[id]:] - tree_attn_mask_list.append(x) - - tree_indices_list = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] - repeat_nums = [[] for _ in depth_counts] - start = 0 - bias = 0 - for i in range(len(depth_counts)): - bias = 0 - repeat_j = 0 - for j in range(depth_counts[i]): - cur_node = nodes_wc[start + j] - cur_parent = cur_node.parent - - if j != 0: - if cur_parent != parent: - bias += 1 - parent = cur_parent - repeat_nums[i].append(j - repeat_j) - repeat_j = j - else: - parent = cur_parent - tree_indices_list[i][j] = cur_node.value + TOPK * (bias) - repeat_nums[i].append(j - repeat_j + 1) - start += depth_counts[i] - - position_ids = [torch.zeros(ml, dtype=torch.long) for ml in depth_counts] - - tree_buffers = { - "attn_mask": [i.unsqueeze(0).unsqueeze(0) for i in tree_attn_mask_list], - "tree_indices": tree_indices_list, - "position_ids": position_ids, - "repeat_nums": repeat_nums - } - - # Move the tensors in the dictionary to the specified device - tree_buffers = { - k: [i.clone().to(device) for i in v] - if isinstance(v[0], torch.Tensor) - else ( - torch.tensor(v, device=device) - if isinstance(v, torch.Tensor) - else v - ) - for k, v in tree_buffers.items() - } - return tree_buffers - -class EAGLEModel(nn.Module): - """ - EAGLEModel is a custom PyTorch model designed for the EAGLE architecture, which includes a series of decoder layers and a final linear layer. - It also supports gradient checkpointing and initialization of tree buffers for the EAGLE model. - """ - - def __init__(self, config, bias=True): - """ - Initializes the EAGLEModel with the given configuration and an optional bias parameter for the final linear layer. - - Parameters: - - config: A configuration object containing model parameters. - - bias: A boolean indicating whether to include bias in the final linear layer. Default is True. - """ - super().__init__() - self.gradient_checkpointing = False - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([EAGLEDecoderLayer(config, index) for index in range(config.num_hidden_layers)]) - self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias) - self.act = ACT2FN[config.hidden_act] - - def init_tree(self): - """ - Initializes the tree buffers for the EAGLE model. - """ - self.tree_buffer = generate_tree_buffers_for_eagle(self.tree, self.embed_tokens.weight.device) - - def reset(self): - """ - Resets the tree mask to None. - """ - self.tree_mask = None - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - """ - Prepares the decoder attention mask by creating a causal mask and combining it with the provided attention mask. - - Parameters: - - attention_mask: The provided attention mask. - - input_shape: The shape of the input tensor. - - inputs_embeds: The input embeddings. - - past_key_values_length: The length of past key values. - - Returns: - - The combined attention mask. - """ - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - torch.float32, # [MODIFIED] force to cast to float32 - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - # [MODIFIED] add tree mask - if hasattr(self, "tree_mask") and self.tree_mask is not None: - tree_mask = self.tree_mask - tree_len = tree_mask.size(-1) - bs = combined_attention_mask.size(0) - combined_attention_mask[:, :, -tree_len:, -tree_len:][ - tree_mask.repeat(bs, 1, 1, 1) == 0 - ] = torch.finfo(torch.float32).min - - return combined_attention_mask - - def forward( - self, - hidden_states, - input_ids, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ): - """ - Forward pass of the EAGLEModel. - - Parameters: - - hidden_states: The input hidden states. - - input_ids: The input token IDs. - - attention_mask: An optional attention mask. - - position_ids: An optional tensor of position IDs. - - past_key_values: An optional list of past key values. - - use_cache: A boolean indicating whether to use caching. - - output_attentions: A boolean indicating whether to output attention weights. - - output_hidden_states: A boolean indicating whether to output hidden states. - - Returns: - - The output hidden states and optionally the next decoder cache and hidden states. - """ - batch_size, seq_length, _ = hidden_states.shape - seq_length_with_past = seq_length - past_key_values_length = 0 - - with torch.no_grad(): - inputs_embeds = self.embed_tokens(input_ids) - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - device = hidden_states.device if hidden_states is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) - - inputs_embeds = inputs_embeds.to(hidden_states.dtype) - hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) - - all_hidden_states = () if output_hidden_states else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if use_cache: - return hidden_states, next_decoder_cache - - return hidden_states - - def reset_kv(self): - """ - Resets the stable key-value pairs to None. - """ - self.stable_kv = None - - @torch.no_grad() - def repeat_hidden(self, hidden_state, repeat_num): - """ - Repeats the hidden state according to the given repeat numbers. - - Parameters: - - hidden_state: The input hidden state. - - repeat_num: A list of repeat numbers. - - Returns: - - The repeated hidden state. - """ - new_hidden = [] - for id, i in enumerate(repeat_num): - new_hidden.append(hidden_state[:, id:id + 1].repeat(1, i, 1)) - return torch.cat(new_hidden, dim=1) - - - def sample(self, logits, logits_processor, k=1): - """ - Samples from the logits using the provided logits processor. - - Parameters: - - logits: The input logits. - - logits_processor: A function to process the logits. - - k: The number of samples to generate. Default is 1. - - Returns: - - The sampled indices, probabilities, and optionally the operation. - """ - bs, seq_len, _ = logits.shape - logits = logits.view(-1, logits.shape[-1]) - logits = logits_processor(None, logits) - probabilities = torch.nn.functional.softmax(logits, dim=-1) - sampled_indices = torch.multinomial(probabilities, k, replacement=False) - sampled_probs = torch.gather(probabilities, -1, sampled_indices) - cumulative_sum = torch.cumsum(sampled_probs, dim=-1) - cumulative_sum = torch.cat( - (torch.zeros(cumulative_sum.shape[0], 1, device=cumulative_sum.device), cumulative_sum[:, :-1]), dim=-1) - sampled_probs = sampled_probs / (1 - cumulative_sum) - sampled_probs[torch.isinf(sampled_probs)] = -1 - sampled_probs[torch.isnan(sampled_probs)] = -1 - sampled_probs = torch.clamp(sampled_probs, min=0.0, max=1.0) - sampled_indices = sampled_indices.view(bs, seq_len, -1) - sampled_probs = sampled_probs.view(bs, seq_len, -1) - probabilities = probabilities.view(bs, seq_len, -1) - - return sampled_indices, sampled_probs, probabilities - - @torch.no_grad() - def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_length=4, use_cache=True, - attention_mask=None, len_posi=None, ): - """ - Generates sequences using top-k sampling. - - Parameters: - - hidden_states: The input hidden states. - - input_ids: The input token IDs. - - head: The head layer for processing the hidden states. - - logits_processor: A function to process the logits. - - max_length: The maximum length of the generated sequence. Default is 4. - - use_cache: A boolean indicating whether to use caching. Default is True. - - attention_mask: An optional attention mask. - - len_posi: The length of the position IDs. - - Returns: - - The generated token indices, probabilities, and optionally the operation. - """ - top_k = 5 - bs = input_ids.shape[0] - input_ids = input_ids[:, 1:] - input_ids = input_ids.to(hidden_states.device) - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.to(self.device) - zero_num = position_ids.shape[1] - position_ids.max(dim=-1).values - 1 - zero_num = zero_num[:, None] - ss_token, ss_prob, ss_op = [], [], [] - if len_posi is None: - len_posi = input_ids.shape[1] - self.reset() - if use_cache: - if hasattr(self, "stable_kv") and self.stable_kv is not None: - kv_len = self.stable_kv[0][0].shape[2] - position_ids = position_ids[:, kv_len:] - out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, past_key_values=self.stable_kv, - use_cache=True, attention_mask=attention_mask, - position_ids=position_ids) - else: - out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, use_cache=True, - attention_mask=attention_mask, position_ids=position_ids) - self.stable_kv = past_key_values - last_nopadding = position_ids.argmax(dim=-1) - ab = tuple(range(bs)) - last_hidden = out_hidden[ab, last_nopadding][:, None] - if not self.diff_device: - last_headout = head(last_hidden) - else: - if hasattr(self, "layer_device"): - last_headout = head(last_hidden) - last_headout = last_headout.to(self.layer_device) - else: - last_headout = F.linear(last_hidden, self.headweight) - - for i in range(len(self.tree_buffer['tree_indices'])): - if logits_processor is not None: - topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) - else: - topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, - top_k, - dim=-1).values - op = None - - ss_token.append(topk_index) - ss_prob.append(topk_prob) - ss_op.append(op) - - input_ids = topk_index.view(bs, -1)[:, self.tree_buffer['tree_indices'][i]] - - attention_mask = torch.cat((attention_mask, torch.ones_like(input_ids, device=attention_mask.device, - dtype=attention_mask.dtype)), dim=1) - - if i == 0: - hidden_states = last_hidden - else: - hidden_states = out_hidden - hidden_states = self.repeat_hidden(hidden_states, self.tree_buffer["repeat_nums"][i]) - self.tree_mask = self.tree_buffer['attn_mask'][i] - position_ids = len_posi + self.tree_buffer["position_ids"][i][None, :] - zero_num - out_hidden, past_key_values = self(hidden_states, input_ids=input_ids, past_key_values=past_key_values, - position_ids=position_ids, use_cache=True, - attention_mask=attention_mask) - len_posi += 1 - - if not self.diff_device: - last_headout = head(out_hidden) - else: - if hasattr(self, "layer_device"): - last_headout = head(out_hidden) - last_headout = last_headout.to(self.layer_device) - else: - last_headout = F.linear(out_hidden[0], self.headweight) - - if logits_processor is not None: - topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) - else: - topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, top_k, - dim=-1).values - op = None - ss_token.append(topk_index) - ss_prob.append(topk_prob) - ss_op.append(op) - - else: - # TODO - pass - - return (torch.cat(ss_token, dim=1), torch.cat(ss_prob, dim=1), ss_op) - -def prepare_logits_processor( - temperature=0.0, repetition_penalty=0.0, top_p=0.0, top_k=0 -) -> LogitsProcessorList: - """ - Prepares a list of logits processors based on the provided parameters. - - Parameters: - - temperature: The temperature to apply to the logits before applying softmax. - - repetition_penalty: The penalty for repeating tokens. - - top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. - - top_k: The number of highest probability vocabulary tokens to keep for top-k filtering. - - Returns: - - A LogitsProcessorList containing the specified logits processors. - """ - processor_list = LogitsProcessorList() - if temperature >= 1e-5 and temperature != 1.0: - processor_list.append(TemperatureLogitsWarper(temperature)) - if repetition_penalty > 1.0: - processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) - if 1e-8 <= top_p < 1.0: - processor_list.append(TopPLogitsWarper(top_p)) - if top_k > 0: - processor_list.append(TopKLogitsWarper(top_k)) - return processor_list - - -def pad_path(path, length, pad_value=-2): - """ - Pad the given path list with a specific value up to a specified length. - - Parameters: - - path (list): The original list that needs padding. - - length (int): The desired length of the padded list. - - pad_value (optional, default=-2): The value to use for padding. - - Returns: - - list: A new list based on the original path but padded to the desired length. - - Example: - >>> pad_path([1,2,3], 5) - [1, 2, 3, -2, -2] - - Note: - If the given path is already longer than the specified length, - then no padding occurs, and the original path is returned. - """ - - # Calculate the number of padding values needed by subtracting the length - # of the path from the desired length. - # Append the padding values to the original path and return the new list. - return path + [pad_value] * (length - len(path)) - - -def generate_tree_buffers(tree_choices, device="cuda"): - """ - Generates tree buffers for a given set of tree choices. - - Parameters: - - tree_choices (list): A list of tree choices. - - device (str, optional): The device to use for tensor operations. Default is "cuda". - - Returns: - - A dictionary containing the generated tree buffers. - """ - TOPK = 5 - sorted_tree_choices = sorted(tree_choices, key=lambda x: (len(x), x)) - tree_len = len(sorted_tree_choices) + 1 - - # Initialize depth_counts to keep track of how many choices have a particular depth - depth_counts = [] - prev_depth = 0 - for path in sorted_tree_choices: - depth = len(path) - if depth != prev_depth: - depth_counts.append(0) - depth_counts[depth - 1] += 1 - prev_depth = depth - - tree_attn_mask = torch.eye(tree_len, tree_len) - tree_attn_mask[:, 0] = 1 - start = 0 - for i in range(len(depth_counts)): - for j in range(depth_counts[i]): - cur_tree_choice = sorted_tree_choices[start + j] - # retrieve ancestor position - if len(cur_tree_choice) == 1: - continue - ancestor_idx = [] - for c in range(len(cur_tree_choice) - 1): - ancestor_idx.append(sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) - tree_attn_mask[j + start + 1, ancestor_idx] = 1 - start += depth_counts[i] - - tree_indices = torch.zeros(tree_len, dtype=torch.long) - p_indices = [0 for _ in range(tree_len - 1)] - b_indices = [[] for _ in range(tree_len - 1)] - tree_indices[0] = 0 - start = 0 - bias = 0 - for i in range(len(depth_counts)): - inlayer_bias = 0 - b = [] - for j in range(depth_counts[i]): - cur_tree_choice = sorted_tree_choices[start + j] - cur_parent = cur_tree_choice[:-1] - if j != 0: - if cur_parent != parent: - bias += 1 - inlayer_bias += 1 - parent = cur_parent - b = [] - else: - parent = cur_parent - tree_indices[start + j + 1] = cur_tree_choice[-1] + TOPK * (i + bias) + 1 - p_indices[start + j] = inlayer_bias - if len(b) > 0: - b_indices[start + j] = copy.deepcopy(b) - else: - b_indices[start + j] = [] - b.append(cur_tree_choice[-1] + TOPK * (i + bias) + 1) - start += depth_counts[i] - - p_indices = [-1] + p_indices - tree_position_ids = torch.zeros(tree_len, dtype=torch.long) - start = 0 - for i in range(len(depth_counts)): - tree_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 - start += depth_counts[i] - - retrieve_indices_nest = [] - retrieve_paths = [] - for i in range(len(sorted_tree_choices)): - cur_tree_choice = sorted_tree_choices[-i - 1] - retrieve_indice = [] - if cur_tree_choice in retrieve_paths: - continue - else: - for c in range(len(cur_tree_choice)): - retrieve_indice.append(sorted_tree_choices.index(cur_tree_choice[:c + 1])) - retrieve_paths.append(cur_tree_choice[:c + 1]) - retrieve_indices_nest.append(retrieve_indice) - max_length = max([len(x) for x in retrieve_indices_nest]) - retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest] - retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) - retrieve_indices = retrieve_indices + 1 - retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], - dim=1) - - maxitem = retrieve_indices.max().item() + 5 - - def custom_sort(lst): - # sort_keys=[len(list)] - sort_keys = [] - for i in range(len(lst)): - sort_keys.append(lst[i] if lst[i] >= 0 else maxitem) - return sort_keys - - retrieve_indices = retrieve_indices.tolist() - retrieve_indices = sorted(retrieve_indices, key=custom_sort) - retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) - - p_indices = torch.tensor(p_indices) - p_indices_new = p_indices[retrieve_indices] - p_indices_new = p_indices_new.tolist() - - b_indices = [[]] + b_indices - b_indices_new = [] - for ib in range(retrieve_indices.shape[0]): - iblist = [] - for jb in range(retrieve_indices.shape[1]): - index = retrieve_indices[ib, jb] - if index == -1: - iblist.append([]) - else: - b = b_indices[index] - if len(b) > 0: - bt = [] - for bi in b: - bt.append(torch.where(tree_indices == bi)[0].item()) - iblist.append(torch.tensor(bt, device=device)) - else: - iblist.append(b) - b_indices_new.append(iblist) - - # Aggregate the generated buffers into a dictionary - tree_buffers = { - "tree_attn_mask": tree_attn_mask.unsqueeze(0).unsqueeze(0), - "tree_indices": tree_indices, - "tree_position_ids": tree_position_ids, - "retrieve_indices": retrieve_indices, - } - - # Move the tensors in the dictionary to the specified device - tree_buffers = { - k: v.clone().to(device) - if isinstance(v, torch.Tensor) - else torch.tensor(v, device=device) - for k, v in tree_buffers.items() - } - tree_buffers["p_indices"] = p_indices_new - tree_buffers["b_indices"] = b_indices_new - return tree_buffers - -def _prepare_decoder_attention_mask( - attention_mask, tree_mask, input_shape, inputs_embeds, past_key_values_length -): - """ - Prepares the decoder attention mask by combining causal masks and attention masks. - - Parameters: - - attention_mask (torch.Tensor): The attention mask tensor. - - tree_mask (torch.Tensor): The tree mask tensor. - - input_shape (tuple): The shape of the input tensor. - - inputs_embeds (torch.Tensor): The input embeddings tensor. - - past_key_values_length (int): The length of past key values. - - Returns: - - combined_attention_mask (torch.Tensor): The combined attention mask tensor. - """ - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - torch.float32, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - if tree_mask is not None: - tree_len = tree_mask.size(-1) - bs = combined_attention_mask.size(0) - combined_attention_mask[:, :, -tree_len:, -tree_len:][ - tree_mask.repeat(bs, 1, 1, 1) == 0 - ] = combined_attention_mask.min() - - return combined_attention_mask - -@torch.no_grad() -def forward_with_tree_mask( - model, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - tree_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values=None, # [MODIFIED] past_key_value is KVCache class - inputs_embeds: Optional[torch.FloatTensor] = None, -): - """ - Forward pass through the model with tree mask applied. - - Parameters: - - model (torch.nn.Module): The model to forward through. - - input_ids (torch.LongTensor): The input IDs tensor. - - attention_mask (torch.Tensor, optional): The attention mask tensor. - - tree_mask (torch.Tensor, optional): The tree mask tensor. - - position_ids (torch.LongTensor, optional): The position IDs tensor. - - past_key_values (KVCache, optional): The past key values. - - inputs_embeds (torch.FloatTensor, optional): The input embeddings tensor. - - Returns: - - hidden_states (torch.Tensor): The output hidden states tensor. - - next_cache (Cache): The next cache. - """ - output_attentions = False - use_cache = True - batch_size, seq_length = input_ids.shape - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - inputs_embeds = model.embed_tokens(input_ids) - - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = _prepare_decoder_attention_mask( - attention_mask, - tree_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - hidden_states = inputs_embeds - next_decoder_cache = () - - for idx, decoder_layer in enumerate(model.layers): - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - hidden_states = model.norm(hidden_states) - - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - return hidden_states, next_cache - -def initialize_tree(input_ids, model, logits_processor, attention_mask=None): - """ - Initializes the tree by forwarding through the model with a tree mask applied. - - Parameters: - - input_ids (torch.LongTensor): The input IDs tensor. - - model (torch.nn.Module): The model to forward through. - - logits_processor (LogitsProcessorList, optional): An instance of LogitsProcessorList used to modify the prediction scores. - - attention_mask (torch.Tensor, optional): The attention mask tensor. - - Returns: - - tree_logits (torch.Tensor): The tree logits tensor. - - logits (torch.Tensor): The logits tensor. - - hidden_states (torch.Tensor): The hidden states tensor. - - token (torch.Tensor): The token tensor. - - past_key_value (KVCache): The past key values. - """ - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - hidden_states, past_key_value = forward_with_tree_mask(model.base_model.model, input_ids=input_ids, - attention_mask=attention_mask, position_ids=position_ids) - logits = model.base_model.lm_head(hidden_states) - - if logits_processor is not None: - sample_logits = logits[:, -1] - sample_logits = logits_processor(None, sample_logits) - probabilities = torch.nn.functional.softmax(sample_logits, dim=-1) - token = torch.multinomial(probabilities, 1) - else: - token = torch.argmax(logits[:, -1], dim=-1) - token = token[:, None] - input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) - - tree_logits = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head, logits_processor, - attention_mask=attention_mask) - - return tree_logits, logits, hidden_states, token, past_key_value - -def generate_candidates(tree_logits, tree_indices, retrieve_indices, sample_token, logits_processor): - """ - Generates candidates for the tree. - - Parameters: - - tree_logits (torch.Tensor): The tree logits tensor. - - tree_indices (torch.Tensor): The tree indices tensor. - - retrieve_indices (torch.Tensor): The retrieve indices tensor. - - sample_token (torch.Tensor): The sample token tensor. - - logits_processor (LogitsProcessorList, optional): An instance of LogitsProcessorList used to modify the prediction scores. - - Returns: - - cart_candidates (torch.Tensor): The cartesian product candidates tensor. - - cart_candidates_prob (torch.Tensor, optional): The cartesian product candidates probabilities tensor. - - tree_candidates (torch.Tensor): The tree candidates tensor. - """ - bs = sample_token.shape[0] - sample_token = sample_token.to(tree_indices.device) - - candidates_logit = sample_token - candidates_tree_logits = tree_logits[0] - - candidates = torch.cat([candidates_logit, candidates_tree_logits.view(bs, -1)], dim=-1) - - tree_candidates = candidates[:, tree_indices] - - tree_candidates_ext = torch.cat( - [tree_candidates, torch.zeros((bs, 1), dtype=torch.long, device=tree_candidates.device)-1], dim=-1) - - cart_candidates = tree_candidates_ext[:, retrieve_indices] - - if logits_processor is not None: - candidates_tree_prob = tree_logits[1] - candidates_prob = torch.cat( - [torch.ones((bs, 1), device=candidates_tree_prob.device, dtype=torch.float32), - candidates_tree_prob.view(bs, -1)], - dim=-1) - - tree_candidates_prob = candidates_prob[:, tree_indices] - tree_candidates_prob_ext = torch.cat( - [tree_candidates_prob, torch.ones((bs, 1), dtype=torch.float32, device=tree_candidates_prob.device)], - dim=-1) - cart_candidates_prob = tree_candidates_prob_ext[:, retrieve_indices] - else: - cart_candidates_prob = None - - return cart_candidates, cart_candidates_prob, tree_candidates - -def tree_decoding( - model, - tree_candidates, - past_key_values, - tree_position_ids, - input_ids, - retrieve_indices, - attention_mask=None, - tree_mask=None, -): - - zero_num = attention_mask.shape[1]-attention_mask.long().sum(-1) - zero_num = zero_num[:, None] - position_ids = tree_position_ids[None,:] + input_ids.shape[1]-zero_num - - - attention_mask = torch.cat( - (attention_mask, torch.ones_like(tree_candidates, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1) - - hidden_states, past_key_value = forward_with_tree_mask(model.base_model.model, input_ids=tree_candidates,past_key_values=past_key_values, - attention_mask=attention_mask, tree_mask=tree_mask,position_ids=position_ids) - - tree_logits = model.base_model.lm_head(hidden_states) - - - - - logits = tree_logits[:, retrieve_indices] - return logits, hidden_states,past_key_value - - -def evaluate_posterior( - logits, candidates, logits_processor, cart_candidates_prob, op, p_indices, tree_candidates, b_indices, - finish_flag -): - """ - Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. - - Depending on the temperature value, the function either uses greedy decoding or evaluates posterior - probabilities to select the best candidate. - - Args: - - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). - - candidates (torch.Tensor): Candidate token sequences. - - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. - - posterior_threshold (float): Threshold for posterior probability. - - posterior_alpha (float): Scaling factor for the threshold. - - Returns: - - best_candidate (torch.Tensor): Index of the chosen best candidate. - - accept_length (int): Length of the accepted candidate sequence. - """ - # Greedy decoding based on temperature value - if logits_processor is None: - bs = tree_candidates.size(0) - # Find the tokens that match the maximum logits for each position in the sequence - posterior_mask = ( - candidates[:, :, 1:].to(logits.device) == torch.argmax(logits[:, :, :-1], dim=-1) - ).int() - candidates_accept_length = (torch.cumprod(posterior_mask, dim=-1)).sum(dim=-1) - accept_length = candidates_accept_length.max(dim=1).values - best_candidate = torch.argmax(candidates_accept_length, dim=-1).to(torch.long) - - - bt = tuple(range(bs)) - logits_batch = logits[bt, best_candidate, accept_length, :] - accept_length = accept_length.tolist() - - for batch in range(bs): - if finish_flag[batch]: - accept_length[batch] = 0 - - return best_candidate.tolist(), accept_length, logits_batch - - else: - cart_candidates_prob = cart_candidates_prob.to(logits.device) - bs = cart_candidates_prob.size(0) - - logits = logits_processor(None, logits) - probs = torch.softmax(logits, dim=-1) - - best_candidate_list = [] - accept_length_list = [] - sample_p_list = [] - - for batch in range(bs): - accept_length = 1 - accept_cand = candidates[batch, 0, :1] - best_candidate = 0 - for i in range(1, candidates.shape[2]): - if i != accept_length: - break - adjustflag = False - is_eq = (candidates[batch, :, :accept_length] == accept_cand).all(dim=1) - fi = torch.nonzero(is_eq, as_tuple=True)[0][0] - gtp = probs[batch, fi, i - 1] - candidates_set = [] - for j in range(candidates.shape[1]): - if is_eq[j]: - x = candidates[batch, j, i] - xi = x.item() - if xi in candidates_set or xi == -1: - continue - candidates_set.append(xi) - r = random.random() - px = gtp[xi] - qx = cart_candidates_prob[batch, j, i] - if qx <= 0: - continue - acp = px / qx - if r <= acp: - accept_cand = torch.cat((accept_cand, x[None]), dim=0) - accept_length += 1 - best_candidate = j - break - else: - q = op[i - 1][batch][p_indices[j][i]].clone() - b = b_indices[j][i] - if len(b) > 0: - mask = tree_candidates[batch][b] - q[mask] = 0 - q = q / q.sum() - gtp = gtp - q - gtp[gtp < 0] = 0 - gtp = gtp / gtp.sum() - adjustflag = True - if adjustflag and accept_length != candidates.shape[1]: - sample_p = gtp - else: - sample_p = probs[batch, best_candidate, accept_length - 1] - best_candidate_list.append(best_candidate) - accept_length_list.append(accept_length - 1) - sample_p_list.append(sample_p) - - for batch in range(bs): - if finish_flag[batch]: - accept_length_list[batch] = 0 - - return best_candidate_list, accept_length_list, sample_p_list - -@torch.no_grad() -def update_inference_inputs( - input_ids, - attention_mask, - candidates, - best_candidate, - accept_length, - retrieve_indices, - logits_processor, - new_token, - past_key_values, - model, - hidden_state_new, - sample_p, - finish_flag -): - """ - Updates inference inputs based on the best candidate selected. - - Parameters: - - input_ids (torch.Tensor): The input IDs tensor. - - attention_mask (torch.Tensor): The attention mask tensor. - - candidates (torch.Tensor): The candidates tensor. - - best_candidate (torch.Tensor): The best candidate tensor. - - accept_length (list): The list of accepted lengths. - - retrieve_indices (torch.Tensor): The retrieve indices tensor. - - logits_processor (LogitsProcessorList, optional): An instance of LogitsProcessorList used to modify the prediction scores. - - new_token (torch.Tensor): The new token tensor. - - past_key_values (KVCache): The past key values. - - model (torch.nn.Module): The model to forward through. - - hidden_state_new (torch.Tensor): The new hidden state tensor. - - sample_p (torch.Tensor): The sample probability tensor. - - finish_flag (list): The list indicating whether each batch has finished. - - Returns: - - input_ids (torch.Tensor): The updated input IDs tensor. - - tree_logits (torch.Tensor): The tree logits tensor. - - new_token (torch.Tensor): The updated new token tensor. - - None: Placeholder for a return value. - - token (torch.Tensor): The token tensor. - - attention_mask (torch.Tensor): The updated attention mask tensor. - - finish_flag (list): The updated finish flag list. - - new_outs (list): The list of new outputs. - - new_kv (tuple): The updated past key values. - """ - new_outs = [] - finish_flag = copy.deepcopy(finish_flag) - bs = len(best_candidate) - prev_input_len = input_ids.shape[1] - max_acccept_len = max(accept_length) - - retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices[0]] - - ab = tuple(range(bs)) - select_indices = ( - retrieve_indices.cpu()[ab, best_candidate, : max_acccept_len + 1, ...] + prev_input_len - ) - new_input_ids = candidates[ab, best_candidate, : max_acccept_len + 1, ...] - - draft_hidden = retrieve_hidden_state_new[ab, best_candidate, :max_acccept_len + 1] - - new_attention_mask = torch.zeros((bs, max_acccept_len + 1), dtype=torch.long) - - for batch in range(bs): - new_attention_mask[batch, :accept_length[batch] + 1] = 1 - new_o = new_input_ids[batch, : accept_length[batch] + 1].tolist() - new_outs.append(new_o) - if model.base_model.config.eos_token_id in new_o: - finish_flag[batch] = True - new_token[batch] += accept_length[batch] + 1 - - attention_mask = torch.cat((attention_mask, new_attention_mask.to(attention_mask.device)), dim=1) - - batch_dim_indices = torch.tensor(ab)[:, None].expand(-1, max_acccept_len + 1) - - new_kv = () - - for past_key_values_data in past_key_values: - layer_kv = () - for korv in past_key_values_data: - tgt = korv[batch_dim_indices, :, select_indices, :] - tgt = tgt.permute(0, 2, 1, 3) - dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :] - dst.copy_(tgt, non_blocking=True) - layer_kv += (korv[:, :, : prev_input_len + tgt.shape[-2], :],) - new_kv += (layer_kv,) - - input_ids = torch.cat((input_ids, new_input_ids.to(input_ids.device)), dim=1) - - prob = sample_p - if isinstance(prob, list): - prob = torch.stack(prob) - if logits_processor is not None: - token = torch.multinomial(prob, 1) - else: - token = torch.argmax(prob, dim=-1) - token = token[:, None] - - draft_input_ids = torch.cat((new_input_ids, torch.ones(bs, 1, dtype=torch.long, device=new_input_ids.device)), dim=1) - token_ = token[:, 0] - - draft_input_ids[ab, torch.tensor(accept_length, dtype=torch.long) + 1] = token_ - - tree_logits = model.ea_layer.topK_genrate(draft_hidden, - input_ids=draft_input_ids, - head=model.base_model.lm_head, logits_processor=logits_processor, attention_mask=attention_mask, len_posi=input_ids.shape[1]) - - return input_ids, tree_logits, new_token, None, token, attention_mask, finish_flag, new_outs, new_kv - - -class EAGLE: - """ - EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty - - This class integrates a Huggingface causal LLM with the EAGLE model for enhanced text generation capabilities. - - Attributes: - - base_model: The Huggingface causal LLM. - - eagle_path: The path to the EAGLE model. - - use_tree_attn: Whether to use tree attention. Recommended for strong computational power and small batch sizes. - - ea_layer: The EAGLE model layer. - - tree: The tree structure used for attention. - - The following models can be directly accelerated using their corresponding checkpoints. - vicuna-7b-v1.3: yuhuili/EAGLE-Vicuna-7B-v1.3 - vicuna-13b-v1.3: yuhuili/EAGLE-Vicuna-13B-v1.3 - vicuna-33b-v1.3: yuhuili/EAGLE-Vicuna-33B-v1.3 - LLaMA2-Chat-7B: yuhuili/EAGLE-llama2-chat-7B - LLaMA2-Chat-13B: yuhuili/EAGLE-llama2-chat-13B - LLaMA2-Chat-70B: yuhuili/EAGLE-llama2-chat-70B - Mixtral-8x7B-Instruct-v0.1: yuhuili/EAGLE-mixtral-instruct-8x7B - Other models need to be trained independently. - Please refer to https://github.com/SafeAILab/EAGLE for more information. - - """ - def __init__(self, base_model, eagle_path, use_tree_attn=True): - """ - Initializes the EAGLE class with the base model and EAGLE model path. - - Args: - - base_model: Huggingface causal LLM. - - eagle_path: Path of EAGLE. - - use_tree_attn: Whether to use tree attention. Default is True. - """ - self.base_model = base_model - - configpath = os.path.join(eagle_path, "config.json") - if not os.path.exists(configpath): - configpath = hf_hub_download(eagle_path, "config.json") - - load_model_path = os.path.join(eagle_path, "pytorch_model.bin") - if not os.path.exists(load_model_path): - load_model_path = hf_hub_download(eagle_path, "pytorch_model.bin") - - config = EAGLE_Config.from_pretrained(configpath) - with open(configpath, "r") as f: - con = json.loads(f.read()) - try: - bias = con["bias"] - except: - bias = True - - self.ea_layer = EAGLEModel(config, bias=bias) - - ea_layer_state_dict = torch.load(load_model_path, map_location="cpu") - self.ea_layer.load_state_dict(ea_layer_state_dict, strict=True) - device = base_model.model.layers[-1].self_attn.q_proj.weight.device - self.ea_layer.to(self.base_model.dtype).to(device) - self.ea_layer.device = device - - if device != base_model.lm_head.weight.device: - self.ea_layer.diff_device = True - self.ea_layer.headweight = base_model.lm_head.weight.clone().to(device) - else: - self.ea_layer.diff_device = False - - if use_tree_attn: - self.ea_layer.tree = tree_structure - self.tree = tree_structure - else: - self.ea_layer.tree = chain_structure - self.tree = chain_structure - - self.ea_layer.init_tree() - self.base_model.eval() - self.ea_layer.eval() - - @torch.no_grad() - def generate( - self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.LongTensor] = None, - temperature=0.0, - top_p=0.0, - top_k=0, - max_new_tokens=512, - max_length=2048, - ) -> torch.LongTensor: - """ - Generates text using the EAGLE model. - - Args: - - input_ids: The input IDs tensor. - - attention_mask: The attention mask tensor. - - temperature: The temperature for sampling. Default is 0.0. - - top_p: The top-p value for nucleus sampling. Default is 0.0. - - top_k: The top-k value for top-k sampling. Default is 0. - - max_new_tokens: The maximum number of new tokens to generate. Default is 512. - - max_length: The maximum length of the generated text. Default is 2048. - - Returns: - - torch.LongTensor: The generated text as a tensor of token IDs. - """ - tree_choices = self.tree - bs = input_ids.shape[0] - if temperature > 1e-5: - logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) - else: - logits_processor = None - - input_ids = input_ids.clone() - self.ea_layer.reset_kv() - - if hasattr(self, "tree_choices") and self.tree_choices == tree_choices: - tree_buffers = self.tree_buffers - else: - tree_buffers = generate_tree_buffers( - tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device - ) - tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to( - self.base_model.lm_head.weight.device) - tree_buffers["tree_position_ids"] = tree_buffers["tree_position_ids"].to(self.base_model.device) - self.tree_buffers = tree_buffers - self.tree_choices = tree_choices - - tree_buffers["retrieve_indices_batch"] = tree_buffers["retrieve_indices"].expand(bs, -1, -1) - - input_ids=input_ids.to(tree_buffers["tree_position_ids"].device) - attention_mask=attention_mask.to(tree_buffers["tree_position_ids"].device) - - bool_mask = attention_mask.bool() - - out_inputids = [ids.tolist() for ids, mask in zip(input_ids, bool_mask)] - - tree_logits, logits, hidden_state, sample_token,past_key_values = initialize_tree( - input_ids, self, logits_processor, - attention_mask=attention_mask - ) - - new_token = [0]*bs - finish_flag=[False]*bs - - for idx in range(max_length): - candidates, cart_candidates_prob, tree_candidates = generate_candidates( - tree_logits, - tree_buffers["tree_indices"], - tree_buffers["retrieve_indices"], - sample_token, - logits_processor - ) - logits, hidden_state_new,past_key_values = tree_decoding( - self, - tree_candidates, - past_key_values, - tree_buffers["tree_position_ids"], - input_ids, - tree_buffers["retrieve_indices_head"], - attention_mask=attention_mask, - tree_mask=tree_buffers["tree_attn_mask"] - ) - best_candidate, accept_length, sample_p = evaluate_posterior( - logits, candidates, logits_processor, cart_candidates_prob, tree_logits[2], tree_buffers["p_indices"], - tree_candidates, tree_buffers["b_indices"],finish_flag - ) - input_ids, tree_logits, new_token, hidden_state, sample_token, attention_mask, newfinish_flag, new_outs,past_key_values = update_inference_inputs( - input_ids, - attention_mask, - candidates, - best_candidate, - accept_length, - tree_buffers["retrieve_indices_batch"], - logits_processor, - new_token, - past_key_values, - self, - hidden_state_new, - sample_p, - finish_flag - ) - - min_uf_newtokens = max_length + 10 - for batch in range(bs): - if not finish_flag[batch]: - out_inputids[batch].extend(new_outs[batch]) - min_uf_newtokens = min(min_uf_newtokens, new_token[batch]) - finish_flag = newfinish_flag - - if min(finish_flag): - break - if min_uf_newtokens > max_new_tokens: - break - if input_ids.shape[1] + 10 + len(tree_choices) > max_length: - break - - if len(out_inputids)==1: - return out_inputids[0] - return out_inputids From 21bc11f5defa954bdc587453c4470f60a6a62478 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 23 Apr 2024 00:04:25 -0500 Subject: [PATCH 08/12] Update README.md Signed-off-by: Siddhi Velankar --- examples/huggingface/pytorch/text-generation/eagle/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/README.md b/examples/huggingface/pytorch/text-generation/eagle/README.md index 3d716e96a4d..164063b938b 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/README.md +++ b/examples/huggingface/pytorch/text-generation/eagle/README.md @@ -7,7 +7,7 @@ See [here](https://arxiv.org/abs/2401.15077) to view the paper and [here](https: ## Setup and installation -With pip: +With pip: (recommended) ```bash pip install eagle-llm From d089c8ab10d396bbe3ea1908391bdf1263257806 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 23 Apr 2024 00:30:00 -0500 Subject: [PATCH 09/12] tps logic Signed-off-by: Siddhi Velankar --- .../huggingface/pytorch/text-generation/eagle/eagle_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py index c41434ad12a..eabcf2e7063 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py +++ b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py @@ -124,7 +124,7 @@ def parse_argument(): # Update tokens per second based on outs type if isinstance(outs, list): - total_new_tokens = len(outs) + total_new_tokens += len(outs) else: total_new_tokens += len(outs[0]) - inputs.input_ids.shape[-1] # Print output From a63a9bd4e8ddf0e094f479998b9e07af58155c20 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 23 Apr 2024 00:55:18 -0500 Subject: [PATCH 10/12] Update README.md Signed-off-by: Siddhi Velankar --- .../pytorch/text-generation/eagle/README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/README.md b/examples/huggingface/pytorch/text-generation/eagle/README.md index 164063b938b..b7c20328b62 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/README.md +++ b/examples/huggingface/pytorch/text-generation/eagle/README.md @@ -20,7 +20,7 @@ git clone https://github.com/SafeAILab/EAGLE.git cd EAGLE pip install -e . ``` -The modeling_eagle.py script can be found at intel_extension_for_transformers/transformers/modeling/modeling_eagle.py + ## Usage @@ -45,3 +45,13 @@ The default base model is set to "meta-llama/Llama-2-7b-chat-hf", you can change - EAGLE Model Integration: The script demonstrates how to integrate the EAGLE model for enhanced text generation capabilities. - Performance Measurement: The script measures the performance of the text generation process, including the tokens per second (TPS) based on the total new tokens and total time. + +# Results + +We conducted benchmarking tests on both CPU and XPU environments. + +On GPU, llama2-7b-chat + EAGLE outperforms llama2-7b-chat by 3x speed in generating new tokens. +On CPU, llama2-7b-chat + EAGLE outperforms llama2-7b-chat by 1.75x speed in generating new tokens. + +This highlights the superior capabilities of the EAGLE in handling the text generation task. + From 81671166d3b485f0e7181ce11a9be822ea59643f Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 23 Apr 2024 00:57:45 -0500 Subject: [PATCH 11/12] Update README.md Signed-off-by: Siddhi Velankar --- examples/huggingface/pytorch/text-generation/eagle/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/README.md b/examples/huggingface/pytorch/text-generation/eagle/README.md index b7c20328b62..e5dacddd01b 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/README.md +++ b/examples/huggingface/pytorch/text-generation/eagle/README.md @@ -33,7 +33,7 @@ The script accepts several command-line arguments: ```bash python3 eagle_example.py -python main.py -d xpu --max_new_tokens 1024 --use_eagle +python eagle_example.py -d xpu --max_new_tokens 1024 --use_eagle ``` From 8de519dc4ffaac7b8d65948c2b4b007393f76ab8 Mon Sep 17 00:00:00 2001 From: Siddhi Velankar Date: Tue, 30 Apr 2024 10:20:12 -0500 Subject: [PATCH 12/12] text to prompt Signed-off-by: Siddhi Velankar --- .../pytorch/text-generation/eagle/eagle_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py index eabcf2e7063..05af082b2de 100644 --- a/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py +++ b/examples/huggingface/pytorch/text-generation/eagle/eagle_example.py @@ -108,8 +108,8 @@ def parse_argument(): for _ in tqdm(range(10)): t_b = time.time() - # inputs = tokenizer(text, return_tensors="pt", padding=True).to(device) - inputs = tokenizer(text, return_tensors="pt").to(device) + # inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device) + inputs = tokenizer(prompt, return_tensors="pt").to(device) # Generate text using the EAGLE model if the flag is set, otherwise use the base model if use_eagle: