From e43ac7c90e349c340cfbc9e2b2bc07be5527c5be Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 20 Jan 2025 18:37:58 +0100
Subject: [PATCH 1/6] added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

---
 llms/mlx_lm/models/mamba.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index f24146602..70ac70a3e 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -138,10 +138,10 @@ def ssm_step(self, x, state=None):
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
-        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
+        new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
         if state is not None:
             new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
-        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
+        y = mx.einsum('bsd,sd->bs', new_state, C)
         y = y + D * x
         return y, new_state

From 9494a275ac1c7dc8f1eda9c1db23e98976da88a4 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 20 Jan 2025 18:39:22 +0100
Subject: [PATCH 2/6] Fused operations in delta, B, C = ...: before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

---
 llms/mlx_lm/models/mamba.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 70ac70a3e..b7eff756b 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -127,14 +127,10 @@ def ssm_step(self, x, state=None):
         A = -mx.exp(self.A_log)
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
-        )
+        delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+                          mx.split(deltaBC, [self.time_step_rank,
+                                             self.time_step_rank + self.ssm_state_size],
+                                   axis=-1))
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
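Note on patch 1: the state update fused by mx.einsum is an outer product,
new_state[b, s, d] = delta[b, s] * x[b, s] * B[b, d]. A minimal standalone
sketch of the equivalence, with toy shapes assumed (batch 2, intermediate
size 4, state size 3), and with the B operand spelled with a batch
subscript ('bd') so its shape lines up with the broadcast form (the patch
itself spells it 'sd'):

    import mlx.core as mx

    delta = mx.random.normal((2, 4))   # (batch, intermediate)
    x = mx.random.normal((2, 4))       # (batch, intermediate)
    B = mx.random.normal((2, 3))       # (batch, state)

    # Broadcast form, as before the patch: (2, 4, 1) * (2, 1, 3) -> (2, 4, 3)
    ref = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)

    # Fused einsum form: one kernel instead of two expand_dims and a multiply
    fused = mx.einsum('bs,bs,bd->bsd', delta, x, B)

    print(mx.allclose(ref, fused))  # True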
From db582e4f9e1a69d7dbc6ced903d661fe420f5fd1 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 20 Jan 2025 18:42:39 +0100
Subject: [PATCH 3/6] Pre-computing A. Before: 83.890 tokens-per-sec, after: 85.848 tokens-per-sec

---
 llms/mlx_lm/models/mamba.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index b7eff756b..5c09c9994 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -123,14 +123,20 @@ def __init__(self, args: ModelArgs):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )

-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
-                          mx.split(deltaBC, [self.time_step_rank,
-                                             self.time_step_rank + self.ssm_state_size],
-                                   axis=-1))
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [
+                    self.time_step_rank,
+                    self.time_step_rank + self.ssm_state_size
+                ],
+                axis=-1
+            )
+        )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
@@ -143,6 +149,9 @@ def ssm_step(self, x, A, state=None):

     def __call__(self, x, cache):
         B, T, D = x.shape
+
+        A = -mx.exp(self.A_log)
+
         if cache is None:
             cache = [None, None]

@@ -154,7 +163,7 @@ def __call__(self, x, cache):
             conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
             x_t = conv_out.squeeze(1)
             x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
+            y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
             z_t = nn.silu(z_t)
             output_t = y_t * z_t
             output_t = self.out_proj(output_t)
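Note on patch 3: the win here is loop-invariant hoisting. A depends only on
the weights, so recomputing -mx.exp(self.A_log) for every generated token is
wasted work; computing it once per forward call and passing it into ssm_step
removes one exp per token from the decode loop. A schematic sketch of the
pattern (hypothetical step function, not the model code):

    import mlx.core as mx

    def decode(A_log, tokens, step):
        # Before: A = -mx.exp(A_log) was recomputed inside the per-token loop.
        # After: hoisted out, computed once per call.
        A = -mx.exp(A_log)  # loop-invariant
        return [step(t, A) for t in tokens]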
From dfd51f16d6f0883f74cfbb1c49950d90903dd1c2 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 20 Jan 2025 18:59:16 +0100
Subject: [PATCH 4/6] Update MambaBlock: batched input processing, improved cache handling, pre-computed constants, cleaner state management, explicit return values. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec

---
 llms/mlx_lm/models/mamba.py | 52 +++++++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 16 deletions(-)

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 5c09c9994..f8db469ce 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -147,28 +147,48 @@ def ssm_step(self, x, A, state=None):
         y = y + D * x
         return y, new_state

-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
+        """Process a sequence of inputs with cached states"""
         B, T, D = x.shape
+        # Project all tokens at once
+        xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
+        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)  # Fixed: using split instead of chunk
+
+        # Handle convolution with cache
+        conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
+        x_t = nn.silu(conv_out)

+        # Pre-compute A matrix
         A = -mx.exp(self.A_log)
-
-        if cache is None:
-            cache = [None, None]
-
+
+        # Process sequence with state
         outputs = []
+        current_state = state_cache
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
+            y_t, current_state = self.ssm_step(x_t[:, t], A, current_state)
+            z_curr = nn.silu(z_t[:, t])
+            output_t = self.out_proj(y_t * z_curr)
             outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+
+        return mx.stack(outputs, axis=1), (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None or isinstance(cache, list):
+            # Handle legacy cache format
+            conv_cache, state_cache = cache if cache is not None else (None, None)
+        else:
+            # Handle MambaCache object
+            conv_cache, state_cache = cache.state
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        # Update cache
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output
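Note on patch 4: the largest speedup comes from moving the input projection
and the convolution out of the per-token loop, so only the inherently
sequential SSM recurrence stays in Python. A minimal sketch of the batching
idea (toy sizes, hypothetical names; not the model code):

    import mlx.core as mx
    import mlx.nn as nn

    B, T, D = 1, 8, 16
    proj = nn.Linear(D, 2 * D)
    x = mx.random.normal((B, T, D))

    # Per-token, as before: T small matmuls inside the decode loop.
    slow = mx.stack([proj(x[:, t, :]) for t in range(T)], axis=1)

    # Batched, as after: one matmul over all T tokens at once.
    fast = proj(x.reshape(-1, D)).reshape(B, T, -1)

    print(mx.allclose(slow, fast))  # True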
From 0d4f2c4dc0e35eebf051c885043c5303cd37a6fa Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 28 Jan 2025 21:02:50 +0100
Subject: [PATCH 5/6] cleaning up and adding apple copyright to helium modelfile

---
 llms/mlx_lm/models/helium.py |  2 ++
 llms/mlx_lm/models/mamba.py  | 10 +---------
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py
index 6ca46a725..993ce9d51 100644
--- a/llms/mlx_lm/models/helium.py
+++ b/llms/mlx_lm/models/helium.py
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple

diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index f8db469ce..e3877e194 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -148,20 +148,15 @@ def ssm_step(self, x, A, state=None):
         return y, new_state

     def _process_sequence(self, x, conv_cache, state_cache):
-        """Process a sequence of inputs with cached states"""
         B, T, D = x.shape
-        # Project all tokens at once
         xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
-        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)  # Fixed: using split instead of chunk
+        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)

-        # Handle convolution with cache
         conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
         x_t = nn.silu(conv_out)

-        # Pre-compute A matrix
         A = -mx.exp(self.A_log)

-        # Process sequence with state
         outputs = []
         current_state = state_cache
         for t in range(T):
@@ -174,17 +169,14 @@ def _process_sequence(self, x, conv_cache, state_cache):

     def __call__(self, x, cache):
         if cache is None or isinstance(cache, list):
-            # Handle legacy cache format
             conv_cache, state_cache = cache if cache is not None else (None, None)
         else:
-            # Handle MambaCache object
             conv_cache, state_cache = cache.state

         output, (new_conv_cache, new_state_cache) = self._process_sequence(
             x, conv_cache, state_cache
         )

-        # Update cache
         if isinstance(cache, MambaCache):
             cache[0] = new_conv_cache
             cache[1] = new_state_cache

From 6a367fa31ef4ce740b89492ede2de638b86998a8 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 28 Jan 2025 21:04:03 +0100
Subject: [PATCH 6/6] update Copyright to this year

---
 llms/mlx_lm/models/helium.py  | 2 +-
 llms/mlx_lm/models/mamba.py   | 2 +-
 llms/mlx_lm/models/minicpm.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py
index 993ce9d51..ff551bca6 100644
--- a/llms/mlx_lm/models/helium.py
+++ b/llms/mlx_lm/models/helium.py
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2025 Apple Inc.

 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index e3877e194..37fa2092a 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.

 import math
 from dataclasses import dataclass
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index edddd5836..7140c5778 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.

 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union