From 2d82bb54dba595261adfd7a737e24e86a8312824 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 24 Jun 2024 06:28:33 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../models/modeling_gaudi_llama.py              | 10 +++++-----
 .../modeling/kv_cache_compression/prune/h2o.py  |  8 ++++----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py
index ac2cba82bcf..ab6da70bc8e 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py
+++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/models/modeling_gaudi_llama.py
@@ -294,10 +294,10 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
 
         self.norm_factor = 1.0 / math.sqrt(self.head_dim)
         self._init_func = []
-
+
     def register_init_func(self, func):
         self._init_func.append(func)
-
+
     def post_init(self):
         for func in self._init_func:
             func(self)
@@ -946,13 +946,13 @@ def __init__(
         self.model.layers[layer_idx].self_attn.post_init()
 
         self.model.layers[layer_idx].self_attn.pruner = self.pruner
-
+
         # Initialize weights and apply final processing
         self.post_init()
 
         def _generate(*args, **kwargs):
-            self.pruner.before_generate(self, *args, **kwargs)
-            result = self.ori_generate(*args, **kwargs)
+            self.pruner.before_generate(self, *args, **kwargs)
+            result = self.ori_generate(*args, **kwargs)
             self.pruner.after_generate(self,*args, **kwargs)
             return result
 
diff --git a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
index 802ebc495be..ed91b8ddb7b 100644
--- a/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
+++ b/intel_extension_for_transformers/transformers/modeling/kv_cache_compression/prune/h2o.py
@@ -204,8 +204,8 @@ def __init__(self, config: H2OConfig) -> None:
         self.config = config
         self.real_drop = self.config.real_drop
         self.prune_kv_cache_size = None
-
-
+
+
     def self_attn_init(self, module):
         module.h2o_kv_cache = H2OKVCache(
             self.config.heavy_ratio,
@@ -215,7 +215,7 @@ def __init__(self, config: H2OConfig) -> None:
             self.config.h2o_min_seqlen,
             self.config.mean
         )
-
+
     def before_generate(self, model, inputs, *args, **kwargs):
         self.past_length = 0
         max_length = kwargs['max_new_tokens'] if kwargs.get('max_new_tokens') else kwargs['max_length']
@@ -234,7 +234,7 @@ def after_generate(self, model, inputs, *args, **kwargs):
             if "Attention" in module.__class__.__name__:
                 module.h2o_kv_cache.clean_scores()
         self.prune_kv_cache_size = None
-
+
     def prune(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
         attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
         if causal_mask is not None:  # no matter the length, we just slice it