diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py
index fda6a221..ee16bb0b 100644
--- a/gptneox_attn_replace.py
+++ b/gptneox_attn_replace.py
@@ -112,9 +112,7 @@ def shift(qkv, num_heads, head_dim):
             qkv = qkv.transpose(1, 2)
             # qkv = [bsz, q_len, nh, d]
             qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
-            #qkv = qkv.transpose(1, 2)
-            # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well
 
             # -> [bsz * n_group, group_s, nh, d)
             # -> [bsz * n_group, nh, group_s, d)
             qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
@@ -124,6 +122,7 @@ def shift(qkv, num_heads, head_dim):
         query = shift(query, self.num_attention_heads, self.head_size).contiguous()
         key = shift(key, self.num_attention_heads, self.head_size).contiguous()
         value = shift(value, self.num_attention_heads, self.head_size).contiguous()
+        attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
 
         # Compute attention
@@ -136,7 +135,6 @@ def shift(qkv, num_heads, head_dim):
         if self.training and not use_full:
             attn_output = attn_output.transpose(1, 2).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
-            #attn_output = attn_output.transpose(1, 2)
             # [bsz, q_len, nh, hd]
             attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
             attn_output = attn_output.transpose(1, 2)
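
For reference, the added attention_mask line crops the causal mask to a single shift group and tiles it across groups so it matches the [bsz * num_group, nh, group_size, head_dim] tensors produced by shift(). Below is a minimal standalone sketch of that operation, not part of the patch: the mask construction and the concrete bsz/q_len/group_size values are illustrative assumptions, taking a GPT-NeoX-style additive mask of shape [bsz, 1, q_len, q_len] with q_len divisible by group_size.

import torch

# Illustrative sizes only (assumption, not from the patch).
bsz, q_len, group_size = 2, 8, 4
num_group = q_len // group_size

# Additive causal mask: 0 where attention is allowed, -inf above the diagonal.
causal = torch.full((q_len, q_len), float("-inf")).triu(1)
attention_mask = causal[None, None, :, :].expand(bsz, 1, q_len, q_len)

# Crop to one group window and repeat along the batch dimension so the mask
# lines up with Q/K/V reshaped to [bsz * num_group, nh, group_size, head_dim].
group_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)

print(group_mask.shape)  # torch.Size([4, 1, 4, 4]) == [bsz * num_group, 1, group_size, group_size]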