Skip to content

Commit

Permalink
Change gqa to use repeat instead of concatenate (ml-explore#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Feb 15, 2024
1 parent 555bbf5 commit 9edb7ed
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 40 deletions.
7 changes: 2 additions & 5 deletions llms/gguf_llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,9 @@ def __call__(
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.n_heads, L, -1])
-
  if self.repeats > 1:
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
Expand Down
7 changes: 2 additions & 5 deletions llms/mistral/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,8 @@ def __call__(
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.n_heads, L, -1])
-
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
Expand Down
7 changes: 2 additions & 5 deletions llms/mixtral/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,8 @@ def __call__(
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.n_heads, L, -1])
-
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
Expand Down
7 changes: 2 additions & 5 deletions llms/mlx_lm/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,9 @@ def __call__(
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.n_heads, L, -1])
-
  if self.repeats > 1:
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
Expand Down
7 changes: 2 additions & 5 deletions llms/mlx_lm/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,9 @@ def __call__(
0, 2, 1, 3
)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.num_heads, L, -1])
-
  if self.repeats > 1:
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
Expand Down
7 changes: 2 additions & 5 deletions llms/mlx_lm/models/phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,9 @@ def __call__(self, x, mask=None, cache=None):
B, L, self.num_key_value_heads, self.head_dim
).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.num_heads, L, -1])
-
  if self.repeats > 1:
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
Expand Down
7 changes: 2 additions & 5 deletions llms/mlx_lm/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,9 @@ def __call__(
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.n_heads, L, -1])
-
  if self.repeats > 1:
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

if cache is not None:
key_cache, value_cache = cache
Expand Down
7 changes: 2 additions & 5 deletions llms/mlx_lm/models/stablelm_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,9 @@ def __call__(self, x, mask=None, cache=None):
B, L, self.num_key_value_heads, self.head_dim
).transpose(0, 2, 1, 3)

- def repeat(a):
- a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
- return a.reshape([B, self.num_heads, L, -1])
-
  if self.repeats > 1:
- keys, values = map(repeat, (keys, values))
+ keys = mx.repeat(keys, self.repeats, axis=1)
+ values = mx.repeat(values, self.repeats, axis=1)

# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
Expand Down

0 comments on commit 9edb7ed

Please sign in to comment.