From 02e4c1cd3fa748c3a630ffcde3bcd284b9071fd3 Mon Sep 17 00:00:00 2001
From: naubull2
Date: Tue, 3 Oct 2023 17:47:40 +0900
Subject: [PATCH] [remove] unused comments

---
 gptneox_attn_replace.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py
index fda6a221..ee16bb0b 100644
--- a/gptneox_attn_replace.py
+++ b/gptneox_attn_replace.py
@@ -112,9 +112,7 @@ def shift(qkv, num_heads, head_dim):
                 qkv = qkv.transpose(1, 2)
                 # qkv = [bsz, q_len, nh, d]
                 qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
-                #qkv = qkv.transpose(1, 2)
 
-                # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well
                 # -> [bsz * n_group, group_s, nh, d)
                 # -> [bsz * n_group, nh, group_s, d)
                 qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
@@ -124,6 +122,7 @@ def shift(qkv, num_heads, head_dim):
             query = shift(query, self.num_attention_heads, self.head_size).contiguous()
             key = shift(key, self.num_attention_heads, self.head_size).contiguous()
             value = shift(value, self.num_attention_heads, self.head_size).contiguous()
+
             attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
 
         # Compute attention
@@ -136,7 +135,6 @@ def shift(qkv, num_heads, head_dim):
         if self.training and not use_full:
             attn_output = attn_output.transpose(1, 2).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
-            #attn_output = attn_output.transpose(1, 2)
             # [bsz, q_len, nh, hd]
             attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
             attn_output = attn_output.transpose(1, 2)
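
Note (not part of the patch): for context, a minimal standalone sketch of what the
`shift` helper touched above does -- roll the second half of the attention heads by
half a group along the sequence axis, then fold the groups into the batch dimension.
All sizes below are hypothetical, and the 1/4 group-size ratio is only an assumed
example value.

  import torch

  def shift(qkv, num_heads, head_dim, bsz, group_size, num_group):
      # [bsz, nh, q_len, d] -> [bsz, q_len, nh, d]
      qkv = qkv.transpose(1, 2)
      # Roll the second half of the heads by half a group along the sequence axis.
      qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
      # Fold groups into the batch:
      # [bsz * n_group, group_s, nh, d] -> [bsz * n_group, nh, group_s, d]
      return qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)

  bsz, num_heads, q_len, head_dim = 2, 8, 64, 16   # hypothetical sizes
  group_size = q_len // 4                          # assumed 1/4 group-size ratio
  num_group = q_len // group_size
  query = torch.randn(bsz, num_heads, q_len, head_dim)
  print(shift(query, num_heads, head_dim, bsz, group_size, num_group).shape)
  # torch.Size([8, 8, 16, 16])  i.e. [bsz * num_group, nh, group_size, head_dim]

Presumably the TODO about adjusting the attention mask could be dropped because the
mask is already cropped and repeated per group (the `attention_mask[...]` context
line in the second hunk).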