[remove] unused comments
naubull2 committed Oct 3, 2023
1 parent 8a11ef8 commit 02e4c1c
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions gptneox_attn_replace.py
@@ -112,9 +112,7 @@ def shift(qkv, num_heads, head_dim):
         qkv = qkv.transpose(1, 2)
         # qkv = [bsz, q_len, nh, d]
         qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
-        #qkv = qkv.transpose(1, 2)
 
-        # TODO: Changing the q_len to group_size, will require attention mask to be adjusted as well
         # -> [bsz * n_group, group_s, nh, d)
         # -> [bsz * n_group, nh, group_s, d)
         qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
@@ -124,6 +122,7 @@ def shift(qkv, num_heads, head_dim):
         query = shift(query, self.num_attention_heads, self.head_size).contiguous()
         key = shift(key, self.num_attention_heads, self.head_size).contiguous()
         value = shift(value, self.num_attention_heads, self.head_size).contiguous()
+
         attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
 
         # Compute attention
@@ -136,7 +135,6 @@ def shift(qkv, num_heads, head_dim):
         if self.training and not use_full:
             attn_output = attn_output.transpose(1, 2).contiguous()
             attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
-            #attn_output = attn_output.transpose(1, 2)
             # [bsz, q_len, nh, hd]
             attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
             attn_output = attn_output.transpose(1, 2)
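For context, the comments being deleted are leftovers of the shifted-group attention scheme this file appears to patch in: before attention, the second half of the heads is rolled back by half a group and the sequence is folded into groups; after attention, the grouping is undone and the same heads are rolled forward again. Below is a minimal, self-contained sketch of that round trip on a dummy tensor. The sizes (bsz, q_len, num_heads, head_dim, group_size) are toy values assumed purely for illustration, and per-group attention is skipped so the shift/un-shift can be checked to cancel; this is not code from the repository.

import torch

# Toy sizes, assumed for illustration only.
bsz, q_len, num_heads, head_dim = 2, 16, 4, 8
group_size = 4
num_group = q_len // group_size

orig = torch.randn(bsz, num_heads, q_len, head_dim)  # [bsz, nh, q_len, d]
qkv = orig.transpose(1, 2).contiguous()              # [bsz, q_len, nh, d]

# Shift: roll the second half of the heads back by half a group,
# so neighbouring groups overlap and information flows between them.
qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)

# Group: fold the sequence so attention runs within each group.
# [bsz, q_len, nh, d] -> [bsz * n_group, group_s, nh, d] -> [bsz * n_group, nh, group_s, d]
qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)

# (per-group attention would run here; skip it to check the round trip)

# Un-shift: undo the grouping, then roll the same heads forward again.
attn_output = qkv.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, num_heads, head_dim)  # [bsz, q_len, nh, hd]
attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1)
attn_output = attn_output.transpose(1, 2)            # back to [bsz, nh, q_len, d]

print(torch.allclose(attn_output, orig))  # True: shift and un-shift cancel

The mask line in the second hunk matches this grouping: the top-left [group_size, group_size] block of the causal mask is repeated num_group times along the batch dimension, so each folded group attends only within itself.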
