Optimizations for mamba1 #1213

Open
wants to merge 9 commits into main
2 changes: 2 additions & 0 deletions llms/mlx_lm/models/helium.py
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
71 changes: 44 additions & 27 deletions llms/mlx_lm/models/mamba.py
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,47 +123,64 @@ def __init__(self, args: ModelArgs):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [
+                    self.time_step_rank,
+                    self.time_step_rank + self.ssm_state_size
+                ],
+                axis=-1
+            )
         )
-        if self.use_bcdt_rms:
-            delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
-        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
+        new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
         if state is not None:
             new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
-        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
+        y = mx.einsum('bsd,sd->bs', new_state, C)
         y = y + D * x
         return y, new_state
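
The einsum rewrite in ssm_step fuses what were two expand_dims broadcasts plus a matmul-and-squeeze into single contractions. A minimal, shape-consistent sketch of the same update, assuming x and delta are (batch, intermediate), B and C are (batch, state), and A is (intermediate, state) as in the surrounding code; the helper name, explicit-argument signature, and subscript strings here are illustrative, not the PR's exact code:

    import mlx.core as mx

    def ssm_step_sketch(x, delta, A, B, C, D, state=None):
        # x, delta: (batch, intermediate); A: (intermediate, state)
        # B, C: (batch, state); D: (intermediate,)
        # One fused einsum instead of expand_dims(delta * x, -1) * expand_dims(B, 1):
        new_state = mx.einsum("bi,bi,bs->bis", delta, x, B)
        if state is not None:
            # Discretized recurrence: decay the previous state by exp(delta * A).
            new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
        # Contract the state dimension against C, replacing matmul + squeeze.
        y = mx.einsum("bis,bs->bi", new_state, C)
        return y + D * x, new_state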

-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+
+        xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
+        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
+        x_t = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
+
         outputs = []
+        current_state = state_cache
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
+            y_t, current_state = self.ssm_step(x_t[:, t], A, current_state)
+            z_curr = nn.silu(z_t[:, t])
+            output_t = self.out_proj(y_t * z_curr)
             outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+
+        return mx.stack(outputs, axis=1), (new_conv_cache, current_state)
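
_process_sequence hoists the input projection and the causal convolution out of the per-token loop: each is now computed once over the full sequence, so only the inherently sequential SSM recurrence runs per token in Python. A toy demonstration of why the hoisted projection is equivalent (all names and dimensions made up for illustration):

    import mlx.core as mx

    B, T, D, H = 2, 16, 32, 64
    x = mx.random.normal((B, T, D))
    w = mx.random.normal((D, H))

    # Before: T small matmuls issued from a Python loop.
    per_token = mx.stack([x[:, t, :] @ w for t in range(T)], axis=1)

    # After: one (B*T, D) @ (D, H) matmul, reshaped back to (B, T, H).
    hoisted = (x.reshape(-1, D) @ w).reshape(B, T, H)

    print(mx.allclose(per_token, hoisted))  # True, up to float tolerance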

+    def __call__(self, x, cache):
+        if cache is None or isinstance(cache, list):
+            conv_cache, state_cache = cache if cache is not None else (None, None)
+        else:
+            conv_cache, state_cache = cache.state
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output
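
The new __call__ accepts None, a bare [conv, state] list, or a cache object exposing .state, and writes updated state back only when given the object form. A minimal sketch of that dispatch pattern, with a hypothetical ToyCache standing in for mlx_lm's MambaCache:

    class ToyCache:
        def __init__(self):
            self.state = (None, None)

        def __setitem__(self, i, v):
            s = list(self.state)
            s[i] = v
            self.state = tuple(s)

    def unpack(cache):
        # None and list forms pass through; object form is read via .state.
        if cache is None or isinstance(cache, list):
            return cache if cache is not None else (None, None)
        return cache.state

    conv, ssm = unpack(None)      # (None, None)
    c = ToyCache()
    conv, ssm = unpack(c)         # reads c.state
    c[0], c[1] = "conv'", "ssm'"  # write-back, as the diff does for MambaCache
    print(c.state)                # ("conv'", "ssm'")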


2 changes: 1 addition & 1 deletion llms/mlx_lm/models/minicpm.py
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union