
Commit 3c185ad
…n-for-transformers into hengguo/h2o
n1ck-guo committed Jun 25, 2024
2 parents a9488b4 + 2d82bb5 commit 3c185ad
Showing 2 changed files with 9 additions and 9 deletions.
@@ -294,10 +294,10 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        self.norm_factor = 1.0 / math.sqrt(self.head_dim)

        self._init_func = []

    def register_init_func(self, func):
        self._init_func.append(func)

    def post_init(self):
        for func in self._init_func:
            func(self)
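
The two hooks above implement a small deferred-initialization pattern: callbacks are queued with register_init_func and run later by post_init, once the module is fully constructed. A minimal self-contained sketch of the pattern (the class and callback below are hypothetical, not from the source):

    class DeferredInit:
        def __init__(self):
            self._init_func = []

        def register_init_func(self, func):
            # Queue a callback to run once construction is complete.
            self._init_func.append(func)

        def post_init(self):
            # Run every queued callback, passing the module itself.
            for func in self._init_func:
                func(self)

    m = DeferredInit()
    m.register_init_func(lambda mod: setattr(mod, "ready", True))
    m.post_init()
    assert m.ready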
@@ -954,13 +954,13 @@ def __init__(
            self.model.layers[layer_idx].self_attn.post_init()

            self.model.layers[layer_idx].self_attn.pruner = self.pruner

        # Initialize weights and apply final processing
        self.post_init()

        def _generate(*args, **kwargs):
            self.pruner.before_generate(self, *args, **kwargs)
            result = self.ori_generate(*args, **kwargs)
            self.pruner.after_generate(self, *args, **kwargs)
            return result

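The _generate closure above wraps the model's original generate (saved as ori_generate) so the pruner's hooks run before and after every generation call. A hedged sketch of the same monkey-patching idea, with attach_pruner, model, and pruner as hypothetical stand-ins:

    def attach_pruner(model, pruner):
        model.pruner = pruner
        model.ori_generate = model.generate  # keep a handle to the original

        def _generate(*args, **kwargs):
            pruner.before_generate(model, *args, **kwargs)  # e.g. size the KV-cache budget
            result = model.ori_generate(*args, **kwargs)
            pruner.after_generate(model, *args, **kwargs)   # e.g. reset accumulated scores
            return result

        model.generate = _generate  # instance attribute shadows the bound method
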
@@ -204,8 +204,8 @@ def __init__(self, config: H2OConfig) -> None:
        self.config = config
        self.real_drop = self.config.real_drop
        self.prune_kv_cache_size = None

    def self_attn_init(self, module):
        module.h2o_kv_cache = H2OKVCache(
            self.config.heavy_ratio,
@@ -215,7 +215,7 @@ def self_attn_init(self, module):
            self.config.h2o_min_seqlen,
            self.config.mean
        )
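
The visible H2OKVCache arguments come straight from H2OConfig; the fold in the diff hides at least one further argument, so the full constructor signature is not shown. A hedged sketch of just the config fields that appear in this diff, written as a dataclass — default values and comments are assumptions for illustration, not taken from the source:

    from dataclasses import dataclass

    @dataclass
    class H2OConfigSketch:
        heavy_ratio: float = 0.2     # assumed: fraction of "heavy hitter" tokens to keep
        h2o_min_seqlen: int = 1024   # assumed: skip pruning for sequences shorter than this
        mean: bool = False           # assumed: aggregate attention scores by mean rather than sum
        real_drop: bool = True       # assumed: physically drop pruned entries from the KV cache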

    def before_generate(self, model, inputs, *args, **kwargs):
        self.past_length = 0
        max_length = kwargs['max_new_tokens'] if kwargs.get('max_new_tokens') else kwargs['max_length']
@@ -234,7 +234,7 @@ def after_generate(self, model, inputs, *args, **kwargs):
if "Attention" in module.__class__.__name__:
module.h2o_kv_cache.clean_scores()
self.prune_kv_cache_size = None

    def prune(self, module, query_states, key_states, value_states, causal_mask=None, **kwargs):
        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(module.head_dim)
        if causal_mask is not None:  # no matter the length, we just slice it
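
The remainder of prune is folded out of view; the visible lines compute standard scaled dot-product attention scores, with the causal mask sliced to the key length before being applied. A self-contained sketch of just that visible step (the H2O-specific scoring that follows it in the real method is not reproduced here):

    import math

    import torch

    def attn_scores(query_states, key_states, causal_mask=None):
        head_dim = query_states.shape[-1]
        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(head_dim)
        if causal_mask is not None:  # no matter the length, we just slice it
            attn_weights = attn_weights + causal_mask[:, :, :, : key_states.shape[-2]]
        return attn_weights

    q = torch.randn(1, 8, 16, 64)   # (batch, heads, q_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    print(attn_scores(q, k).shape)  # torch.Size([1, 8, 16, 16])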
